Skip to content

Commit

Permalink
changed function signature of BasisGenerator::takeSample.
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamer2368 committed Feb 21, 2024
1 parent 882ada3 commit e16e7ba
Show file tree
Hide file tree
Showing 20 changed files with 79 additions and 71 deletions.
2 changes: 1 addition & 1 deletion examples/misc/combine_samples.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ int main(int argc, char* argv[])
CAROM::Vector snap_cur(num_rows, true);
for (int col = 0; col < num_cols; col++) {
snap_cur = *snapshots->getColumn(col);
static_basis_generator2->takeSample(snap_cur.getData(), 0.0, false);
static_basis_generator2->takeSample(snap_cur.getData(), false);
}

/*-- Compute SVD and save file --*/
Expand Down
4 changes: 2 additions & 2 deletions examples/prom/dg_advection_global_rom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ int main(int argc, char *argv[])
Vector u_curr(*U);
Vector u_centered(U->Size());
subtract(u_curr, u_init, u_centered);
bool addSample = generator->takeSample(u_centered.GetData(), t, dt);
bool addSample = generator->takeSample(u_centered.GetData());
}

// 11. The merge phase
Expand Down Expand Up @@ -832,7 +832,7 @@ int main(int argc, char *argv[])
Vector u_curr(*U);
Vector u_centered(U->Size());
subtract(u_curr, u_init, u_centered);
bool addSample = generator->takeSample(u_centered.GetData(), t, dt);
bool addSample = generator->takeSample(u_centered.GetData());
}

if (done || ti % vis_steps == 0)
Expand Down
4 changes: 2 additions & 2 deletions examples/prom/dg_advection_local_rom_matrix_interp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ int main(int argc, char *argv[])
Vector u_curr(*U);
Vector u_centered(U->Size());
subtract(u_curr, u_init, u_centered);
bool addSample = generator->takeSample(u_centered.GetData(), t, dt);
bool addSample = generator->takeSample(u_centered.GetData());
}

if (online)
Expand Down Expand Up @@ -932,7 +932,7 @@ int main(int argc, char *argv[])
Vector u_curr(*U);
Vector u_centered(U->Size());
subtract(u_curr, u_init, u_centered);
bool addSample = generator->takeSample(u_centered.GetData(), t, dt);
bool addSample = generator->takeSample(u_centered.GetData());
}

if (done || ti % vis_steps == 0)
Expand Down
2 changes: 1 addition & 1 deletion examples/prom/linear_elasticity_global_rom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ int main(int argc, char* argv[])
// 18. take and write snapshot for ROM
if (offline)
{
bool addSample = generator->takeSample(X.GetData(), 0.0, 0.01);
bool addSample = generator->takeSample(X.GetData());
generator->writeSnapshot();
delete generator;
delete options;
Expand Down
2 changes: 1 addition & 1 deletion examples/prom/maxwell_global_rom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ int main(int argc, char *argv[])
// 18. take and write snapshot for ROM
if (offline)
{
bool addSample = generator->takeSample(X.GetData(), 0.0, 0.01);
bool addSample = generator->takeSample(X.GetData());
generator->writeSnapshot();
delete generator;
delete options;
Expand Down
16 changes: 8 additions & 8 deletions examples/prom/mixed_nonlinear_diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ int main(int argc, char *argv[])
if (sampleW && hyperreduce_source)
{
oper.GetSource(source);
basis_generator_S->takeSample(source.GetData(), t, dt);
basis_generator_S->takeSample(source.GetData());
// TODO: dfdt? In this example, one can implement the exact formula.
// In general, one can use finite differences in time (dpdt is computed that way).
//basis_generator_S->computeNextSampleTime(p.GetData(), dfdt.GetData(), t);
Expand All @@ -1224,21 +1224,21 @@ int main(int argc, char *argv[])
{
oper.CopyDpDt(dpdt);

basis_generator_R->takeSample(p.GetData(), t, dt);
basis_generator_R->takeSample(p.GetData());
basis_generator_R->computeNextSampleTime(p.GetData(), dpdt.GetData(), t);

Vector p_R(p.GetData(), N1);
Vector Mp(N1);
oper.SetParameters(p);
oper.Mult_Mmat(p_R, Mp);
basis_generator_FR->takeSample(Mp.GetData(), t, dt);
basis_generator_FR->takeSample(Mp.GetData());
}

if (sampleW)
{
oper.CopyDpDt_W(dpdt);

basis_generator_W->takeSample(p_W->GetData(), t, dt);
basis_generator_W->takeSample(p_W->GetData());
basis_generator_W->computeNextSampleTime(p_W->GetData(), dpdt.GetData(), t);
}
}
Expand Down Expand Up @@ -1403,13 +1403,13 @@ int main(int argc, char *argv[])
oper.CopyDpDt(dpdt);

// R space
basis_generator_R->takeSample(p.GetData(), t, dt);
basis_generator_R->takeSample(p.GetData());

Vector p_R(p.GetData(), N1);
Vector Mp(N1);
oper.SetParameters(p);
oper.Mult_Mmat(p_R, Mp);
basis_generator_FR->takeSample(Mp.GetData(), t, dt);
basis_generator_FR->takeSample(Mp.GetData());

// Terminate the sampling and write out information.
basis_generator_R->writeSnapshot();
Expand All @@ -1418,14 +1418,14 @@ int main(int argc, char *argv[])
// W space

// TODO: why call computeNextSampleTime if you just do takeSample on every step anyway?
basis_generator_W->takeSample(p_W->GetData(), t, dt);
basis_generator_W->takeSample(p_W->GetData());
basis_generator_W->writeSnapshot();

oper.GetSource(source);

if (hyperreduce_source)
{
basis_generator_S->takeSample(source.GetData(), t, dt);
basis_generator_S->takeSample(source.GetData());
basis_generator_S->writeSnapshot();
}

Expand Down
16 changes: 9 additions & 7 deletions examples/prom/nonlinear_elasticity_global_rom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1101,23 +1101,25 @@ int main(int argc, char* argv[])
}

// Take samples
// NOTE(kevin): I don't know why this example checks next sample time.
// IncrementalSVD is never turned on in this example and isNextSample is always true.
if (x_base_only == false && basis_generator_v->isNextSample(t))
{
basis_generator_v->takeSample(vx_diff.GetBlock(0), t, dt);
basis_generator_v->takeSample(vx_diff.GetBlock(0));
basis_generator_v->computeNextSampleTime(vx_diff.GetBlock(0),
dvdt.GetData(), t);
basis_generator_H->takeSample(oper.H_sp.GetData(), t, dt);
basis_generator_H->takeSample(oper.H_sp.GetData());
}

if (basis_generator_x->isNextSample(t))
{
basis_generator_x->takeSample(vx_diff.GetBlock(1), t, dt);
basis_generator_x->takeSample(vx_diff.GetBlock(1));
basis_generator_x->computeNextSampleTime(vx_diff.GetBlock(1),
dxdt.GetData(), t);

if (x_base_only == true)
{
basis_generator_H->takeSample(oper.H_sp.GetData(), t, dt);
basis_generator_H->takeSample(oper.H_sp.GetData());
}
}
}
Expand Down Expand Up @@ -1203,16 +1205,16 @@ int main(int argc, char* argv[])
// Take samples
if (x_base_only == false)
{
basis_generator_v->takeSample(vx_diff.GetBlock(0), t, dt);
basis_generator_v->takeSample(vx_diff.GetBlock(0));
basis_generator_v->writeSnapshot();
delete basis_generator_v;
}

basis_generator_H->takeSample(oper.H_sp.GetData(), t, dt);
basis_generator_H->takeSample(oper.H_sp.GetData());
basis_generator_H->writeSnapshot();
delete basis_generator_H;

basis_generator_x->takeSample(vx_diff.GetBlock(1), t, dt);
basis_generator_x->takeSample(vx_diff.GetBlock(1));
basis_generator_x->writeSnapshot();
delete basis_generator_x;

Expand Down
2 changes: 1 addition & 1 deletion examples/prom/poisson_global_rom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ int main(int argc, char *argv[])
// 18. take and write snapshot for ROM
if (offline)
{
bool addSample = generator->takeSample(X.GetData(), 0.0, 0.01);
bool addSample = generator->takeSample(X.GetData());
generator->writeSnapshot();
delete generator;
delete options;
Expand Down
2 changes: 1 addition & 1 deletion examples/prom/poisson_local_rom_greedy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ int main(int argc, char *argv[])
// 19. take and write snapshot for ROM
if (offline)
{
bool addSample = generator->takeSample(X.GetData(), 0.0, 0.01);
bool addSample = generator->takeSample(X.GetData());
generator->writeSnapshot();
basisIdentifiers.push_back(saveBasisName);
delete generator;
Expand Down
14 changes: 12 additions & 2 deletions lib/linalg/BasisGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,6 @@ BasisGenerator::isNextSample(
bool
BasisGenerator::takeSample(
double* u_in,
double time,
double dt,
bool add_without_increase)
{
CAROM_VERIFY(u_in != 0);
Expand All @@ -143,6 +141,18 @@ BasisGenerator::takeSample(
return false;
}

/*
Note for previous implementation:
Previously with multiple time interval,
there was an input argument (double dt),
which is only used to reset d_dt for new time interval.
Assuming only single interval is used in practice,
resetDt(dt) was never used in takeSample,
and options.initial_dt is used for incremental svd.
*/
// if (d_svd->isNewSample())
// resetDt(dt);

return d_svd->takeSample(u_in, add_without_increase);
}

Expand Down
4 changes: 0 additions & 4 deletions lib/linalg/BasisGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ class BasisGenerator
* @pre time >= 0.0
*
* @param[in] u_in The state at the specified time.
* @param[in] time The simulation time for the state.
* @param[in] dt The current simulation dt.
* @param[in] add_without_increase If true, the addLinearlyDependent is
* invoked. This only applies to incremental
* SVD.
Expand All @@ -113,8 +111,6 @@ class BasisGenerator
bool
takeSample(
double* u_in,
double time,
double dt,
bool add_without_increase = false);

/**
Expand Down
4 changes: 2 additions & 2 deletions unit_tests/random_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,15 @@ main(
bool status = true;
for (int i = 0; i < num_samples; ++i) {
if (inc_basis_generator.isNextSample(0.01*i)) {
status = inc_basis_generator.takeSample(M[i], 0.01*i, 0.01);
status = inc_basis_generator.takeSample(M[i]);
if (!status) {
break;
}
inc_basis_generator.computeNextSampleTime(M[i], M[i], 0.01*i);
}
if (i < num_lin_indep_samples &&
static_basis_generator.isNextSample(0.01*i)) {
status = static_basis_generator.takeSample(M[i], 0.01*i, 0.01);
status = static_basis_generator.takeSample(M[i]);
if (!status) {
break;
}
Expand Down
8 changes: 4 additions & 4 deletions unit_tests/smoke_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ main(
"static_smoke1"));

// Take the first sample.
static_basis_generator->takeSample(&vals0[dim*rank],0,0.1);
static_basis_generator->takeSample(&vals0[dim*rank]);
std::cout << "Writing sample 1" << std::endl;
static_basis_generator->writeSnapshot();
static_basis_generator->endSamples();
Expand All @@ -87,7 +87,7 @@ main(
"static_smoke2"));

// Take the second sample.
static_basis_generator2->takeSample(&vals1[dim*rank],0,0.1);
static_basis_generator2->takeSample(&vals1[dim*rank]);
static_basis_generator2->writeSnapshot(); // "_snapshot" will be added to the base file name
static_basis_generator2->endSamples();

Expand Down Expand Up @@ -120,8 +120,8 @@ main(
static_svd_options, false,
"static_smoke_check"));

static_basis_generator4->takeSample(&vals0[dim*rank],0,0.1);
static_basis_generator4->takeSample(&vals1[dim*rank],0,0.1);
static_basis_generator4->takeSample(&vals0[dim*rank]);
static_basis_generator4->takeSample(&vals1[dim*rank]);

static_basis_generator4->endSamples();
static_basis_generator4 = nullptr;
Expand Down
4 changes: 2 additions & 2 deletions unit_tests/smoke_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ main(

// Take the first sample.
if (inc_basis_generator.isNextSample(0.0)) {
status = inc_basis_generator.takeSample(&vals0[dim*rank], 0.0, 0.11);
status = inc_basis_generator.takeSample(&vals0[dim*rank]);
if (status) {
inc_basis_generator.computeNextSampleTime(&vals0[dim*rank],
&vals0[dim*rank],
Expand All @@ -83,7 +83,7 @@ main(

// Take the second sample.
if (status && inc_basis_generator.isNextSample(0.11)) {
status = inc_basis_generator.takeSample(&vals1[dim*rank], 0.11, 0.11);
status = inc_basis_generator.takeSample(&vals1[dim*rank]);
if (status) {
inc_basis_generator.computeNextSampleTime(&vals1[dim*rank],
&vals1[dim*rank],
Expand Down
6 changes: 3 additions & 3 deletions unit_tests/test_IncrementalSVDBrand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ TEST(IncrementalSVDBrandTest, Test_IncrementalSVDBrand)
incremental_svd_options,
true,
"irrelevant.txt");
sampler.takeSample(&sample1[row_offset[d_rank]], 0, 1e-1);
sampler.takeSample(&sample2[row_offset[d_rank]], 0, 1e-1);
sampler.takeSample(&sample3[row_offset[d_rank]], 0, 1e-1);
sampler.takeSample(&sample1[row_offset[d_rank]]);
sampler.takeSample(&sample2[row_offset[d_rank]]);
sampler.takeSample(&sample3[row_offset[d_rank]]);

const CAROM::Matrix* d_basis = sampler.getSpatialBasis();
const CAROM::Matrix* d_basis_right = sampler.getTemporalBasis();
Expand Down
32 changes: 16 additions & 16 deletions unit_tests/test_RandomizedSVD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ TEST(RandomizedSVDTest, Test_RandomizedSVD)
randomized_svd_options.setDebugMode(true);
randomized_svd_options.setRandomizedSVD(true);
CAROM::BasisGenerator sampler(randomized_svd_options, false);
sampler.takeSample(&sample1[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample2[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample3[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample1[row_offset[d_rank]]);
sampler.takeSample(&sample2[row_offset[d_rank]]);
sampler.takeSample(&sample3[row_offset[d_rank]]);

const CAROM::Matrix* d_basis = sampler.getSpatialBasis();
const CAROM::Matrix* d_basis_right = sampler.getTemporalBasis();
Expand Down Expand Up @@ -154,11 +154,11 @@ TEST(RandomizedSVDTest, Test_RandomizedSVDTransposed)
randomized_svd_options.setDebugMode(true);
randomized_svd_options.setRandomizedSVD(true);
CAROM::BasisGenerator sampler(randomized_svd_options, false);
sampler.takeSample(&sample1[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample2[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample3[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample4[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample5[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample1[row_offset[d_rank]]);
sampler.takeSample(&sample2[row_offset[d_rank]]);
sampler.takeSample(&sample3[row_offset[d_rank]]);
sampler.takeSample(&sample4[row_offset[d_rank]]);
sampler.takeSample(&sample5[row_offset[d_rank]]);

const CAROM::Matrix* d_basis = sampler.getSpatialBasis();
const CAROM::Matrix* d_basis_right = sampler.getTemporalBasis();
Expand Down Expand Up @@ -232,9 +232,9 @@ TEST(RandomizedSVDTest, Test_RandomizedSVDSmallerSubspace)
randomized_svd_options.setDebugMode(true);
randomized_svd_options.setRandomizedSVD(true, 2);
CAROM::BasisGenerator sampler(randomized_svd_options, false);
sampler.takeSample(&sample1[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample2[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample3[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample1[row_offset[d_rank]]);
sampler.takeSample(&sample2[row_offset[d_rank]]);
sampler.takeSample(&sample3[row_offset[d_rank]]);

const CAROM::Matrix* d_basis = sampler.getSpatialBasis();
const CAROM::Matrix* d_basis_right = sampler.getTemporalBasis();
Expand Down Expand Up @@ -311,11 +311,11 @@ TEST(RandomizedSVDTest, Test_RandomizedSVDTransposedSmallerSubspace)
randomized_svd_options.setDebugMode(true);
randomized_svd_options.setRandomizedSVD(true, reduced_rows);
CAROM::BasisGenerator sampler(randomized_svd_options, false);
sampler.takeSample(&sample1[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample2[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample3[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample4[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample5[row_offset[d_rank]], 0, 0);
sampler.takeSample(&sample1[row_offset[d_rank]]);
sampler.takeSample(&sample2[row_offset[d_rank]]);
sampler.takeSample(&sample3[row_offset[d_rank]]);
sampler.takeSample(&sample4[row_offset[d_rank]]);
sampler.takeSample(&sample5[row_offset[d_rank]]);

const CAROM::Matrix* d_basis = sampler.getSpatialBasis();
const CAROM::Matrix* d_basis_right = sampler.getTemporalBasis();
Expand Down
Loading

0 comments on commit e16e7ba

Please sign in to comment.