Skip to content

Commit

Permalink
Merge pull request #67 from njoy/fix/operations
Browse files Browse the repository at this point in the history
Fix/operations
  • Loading branch information
whaeck authored Feb 14, 2024
2 parents 9c71b98 + fcd77fb commit 1e99fd2
Show file tree
Hide file tree
Showing 7 changed files with 448 additions and 75 deletions.
238 changes: 204 additions & 34 deletions python/test/math/Test_scion_math_InterpolationTable.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/scion/math/InterpolationTable/src/evaluateOnGrid.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ std::vector< Y > evaluateOnGrid( const std::vector< X >& x ) const {
std::vector< Y > y( x.size(), Y( 0. ) );

auto xIter = std::lower_bound( x.begin(), x.end(), this->x().front() );
if ( *std::next( xIter ) == this->x().front() ) {

++xIter;
}
auto yIter = std::next( y.begin(), std::distance( x.begin(), xIter ) );

auto xTable = this->x().begin();
Expand Down
2 changes: 0 additions & 2 deletions src/scion/math/InterpolationTable/src/generateTables.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ void generateTables() {

auto xStart = this->x().begin();
auto yStart = this->y().begin();
auto xEnd = xStart;
auto yEnd = yStart;
std::size_t nr = this->boundaries().size();
bool linearised = true;
for ( std::size_t i = 0; i < nr; ++i ) {
Expand Down
33 changes: 15 additions & 18 deletions src/scion/math/InterpolationTable/src/operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,30 +33,27 @@ InterpolationTable& operation( const InterpolationTable& right,
}
else {

// check for threshold tables
if ( this->x().front() != right.x().front() ) {

Y ystart = this->x().front() < right.x().front() ? right.y().front()
: this->y().front();

if ( Y( 0. ) != ystart ) {

X xstart = this->x().front() < right.x().front() ? right.x().front()
: this->x().front();

Log::error( "The threshold table's first y value is not zero" );
Log::info( "Found x = {}", xstart );
Log::info( "Found y = {}", ystart );
throw std::exception();
}
}

// unionise and evaluate on the new grid
std::vector< X > x = unionisation::unionise( this->x(), right.x() );
std::vector< Y > y = this->evaluateOnGrid( x );
std::vector< Y > temp = right.evaluateOnGrid( x );
std::transform( y.begin(), y.end(), temp.begin(), y.begin(), operation );

// check for threshold jump with the same y value
if ( this->x().front() != right.x().front() ) {

X value = this->x().front() < right.x().front() ? right.x().front()
: this->x().front();
auto xIter = std::lower_bound( x.begin(), x.end(), value );
auto yIter = std::next( y.begin(), std::distance( x.begin(), xIter ) );
if ( *std::next( yIter ) == *yIter ) {

x.erase( xIter );
y.erase( yIter );
}
}

// replace this with a new table
*this = InterpolationTable( std::move( x ), std::move( y ) );
}

Expand Down
6 changes: 6 additions & 0 deletions src/scion/math/InterpolationTable/src/processBoundaries.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ processBoundaries( const std::vector< X >& x, const std::vector< Y >& y,
}

auto xIter = std::adjacent_find( x.begin(), x.end() );
if ( xIter == x.begin() ) {

Log::error( "A jump in the x grid cannot occur at the beginning of the x grid" );
throw std::exception();
}

auto bIter = boundaries.begin();
auto iIter = interpolants.begin();
while ( xIter != x.end() ) {
Expand Down
219 changes: 205 additions & 14 deletions src/scion/math/InterpolationTable/test/InterpolationTable.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ SCENARIO( "InterpolationTable" ) {
InterpolationTable< double > result( { 1., 4. }, { 0., 0. } );
InterpolationTable< double > same( { 1., 4. }, { 0., 3. } );
InterpolationTable< double > threshold( { 2., 4. }, { 0., 2. } );
InterpolationTable< double > nonzerothreshold( { 2., 4. }, { 1., 2. } );
InterpolationTable< double > nonzerothreshold( { 2., 4. }, { 1., 3. } );
InterpolationTable< double > small( { 1., 3. }, { 0., 2. } );

chunk += 2.;
Expand Down Expand Up @@ -492,6 +492,98 @@ SCENARIO( "InterpolationTable" ) {
CHECK( InterpolationType::LinearLinear == result.interpolants()[0] );
CHECK( true == std::holds_alternative< IntervalDomain< double > >( result.domain() ) );

chunk += nonzerothreshold;

CHECK( 5 == chunk.numberPoints() );
CHECK( 2 == chunk.numberRegions() );
CHECK( 5 == chunk.x().size() );
CHECK( 5 == chunk.y().size() );
CHECK( 2 == chunk.boundaries().size() );
CHECK( 2 == chunk.interpolants().size() );
CHECK_THAT( 1., WithinRel( chunk.x()[0] ) );
CHECK_THAT( 2., WithinRel( chunk.x()[1] ) );
CHECK_THAT( 2., WithinRel( chunk.x()[2] ) );
CHECK_THAT( 3., WithinRel( chunk.x()[3] ) );
CHECK_THAT( 4., WithinRel( chunk.x()[4] ) );
CHECK_THAT( 4., WithinRel( chunk.y()[0] ) );
CHECK_THAT( 3., WithinRel( chunk.y()[1] ) );
CHECK_THAT( 4., WithinRel( chunk.y()[2] ) );
CHECK_THAT( 4., WithinRel( chunk.y()[3] ) );
CHECK_THAT( 4., WithinRel( chunk.y()[4] ) );
CHECK( 1 == chunk.boundaries()[0] );
CHECK( 4 == chunk.boundaries()[1] );
CHECK( InterpolationType::LinearLinear == chunk.interpolants()[0] );
CHECK( InterpolationType::LinearLinear == chunk.interpolants()[1] );
CHECK( true == std::holds_alternative< IntervalDomain< double > >( chunk.domain() ) );

chunk -= nonzerothreshold;

CHECK( 4 == chunk.numberPoints() );
CHECK( 1 == chunk.numberRegions() );
CHECK( 4 == chunk.x().size() );
CHECK( 4 == chunk.y().size() );
CHECK( 1 == chunk.boundaries().size() );
CHECK( 1 == chunk.interpolants().size() );
CHECK_THAT( 1. , WithinRel( chunk.x()[0] ) );
CHECK_THAT( 2. , WithinRel( chunk.x()[1] ) );
CHECK_THAT( 3. , WithinRel( chunk.x()[2] ) );
CHECK_THAT( 4. , WithinRel( chunk.x()[3] ) );
CHECK_THAT( 4.0, WithinRel( chunk.y()[0] ) );
CHECK_THAT( 3.0, WithinRel( chunk.y()[1] ) );
CHECK_THAT( 2.0, WithinRel( chunk.y()[2] ) );
CHECK_THAT( 1.0, WithinRel( chunk.y()[3] ) );
CHECK( 3 == chunk.boundaries()[0] );
CHECK( InterpolationType::LinearLinear == chunk.interpolants()[0] );
CHECK( true == std::holds_alternative< IntervalDomain< double > >( chunk.domain() ) );

result = chunk + nonzerothreshold;

CHECK( 5 == result.numberPoints() );
CHECK( 2 == result.numberRegions() );
CHECK( 5 == result.x().size() );
CHECK( 5 == result.y().size() );
CHECK( 2 == result.boundaries().size() );
CHECK( 2 == result.interpolants().size() );
CHECK_THAT( 1., WithinRel( result.x()[0] ) );
CHECK_THAT( 2., WithinRel( result.x()[1] ) );
CHECK_THAT( 2., WithinRel( result.x()[2] ) );
CHECK_THAT( 3., WithinRel( result.x()[3] ) );
CHECK_THAT( 4., WithinRel( result.x()[4] ) );
CHECK_THAT( 4., WithinRel( result.y()[0] ) );
CHECK_THAT( 3., WithinRel( result.y()[1] ) );
CHECK_THAT( 4., WithinRel( result.y()[2] ) );
CHECK_THAT( 4., WithinRel( result.y()[3] ) );
CHECK_THAT( 4., WithinRel( result.y()[4] ) );
CHECK( 1 == result.boundaries()[0] );
CHECK( 4 == result.boundaries()[1] );
CHECK( InterpolationType::LinearLinear == result.interpolants()[0] );
CHECK( InterpolationType::LinearLinear == result.interpolants()[1] );
CHECK( true == std::holds_alternative< IntervalDomain< double > >( result.domain() ) );

result = chunk - nonzerothreshold;

CHECK( 5 == result.numberPoints() );
CHECK( 2 == result.numberRegions() );
CHECK( 5 == result.x().size() );
CHECK( 5 == result.y().size() );
CHECK( 2 == result.boundaries().size() );
CHECK( 2 == result.interpolants().size() );
CHECK_THAT( 1., WithinRel( result.x()[0] ) );
CHECK_THAT( 2., WithinRel( result.x()[1] ) );
CHECK_THAT( 2., WithinRel( result.x()[2] ) );
CHECK_THAT( 3., WithinRel( result.x()[3] ) );
CHECK_THAT( 4., WithinRel( result.x()[4] ) );
CHECK_THAT( 4., WithinRel( result.y()[0] ) );
CHECK_THAT( 3., WithinRel( result.y()[1] ) );
CHECK_THAT( 2., WithinRel( result.y()[2] ) );
CHECK_THAT( 0., WithinRel( result.y()[3] ) );
CHECK_THAT( -2., WithinRel( result.y()[4] ) );
CHECK( 1 == result.boundaries()[0] );
CHECK( 4 == result.boundaries()[1] );
CHECK( InterpolationType::LinearLinear == result.interpolants()[0] );
CHECK( InterpolationType::LinearLinear == result.interpolants()[1] );
CHECK( true == std::holds_alternative< IntervalDomain< double > >( result.domain() ) );

// this will add a second point at the lower end point
result = chunk + small;

Expand Down Expand Up @@ -539,12 +631,6 @@ SCENARIO( "InterpolationTable" ) {
CHECK( 4 == result.boundaries()[1] );
CHECK( InterpolationType::LinearLinear == result.interpolants()[0] );
CHECK( true == std::holds_alternative< IntervalDomain< double > >( result.domain() ) );

// the threshold table starts with a non-zero value
CHECK_THROWS( chunk += nonzerothreshold );
CHECK_THROWS( chunk -= nonzerothreshold );
CHECK_THROWS( result = chunk + nonzerothreshold );
CHECK_THROWS( result = chunk - nonzerothreshold );
} // THEN

THEN( "an InterpolationTable can be linearised" ) {
Expand Down Expand Up @@ -687,7 +773,7 @@ SCENARIO( "InterpolationTable" ) {
InterpolationTable< double > result( { 1., 4. }, { 0., 0. } );
InterpolationTable< double > same( { 1., 4. }, { 0., 3. } );
InterpolationTable< double > threshold( { 2., 4. }, { 0., 2. } );
InterpolationTable< double > nonzerothreshold( { 2., 4. }, { 1., 2. } );
InterpolationTable< double > nonzerothreshold( { 3., 4. }, { 1., 2. } );
InterpolationTable< double > small( { 1., 3. }, { 0., 2. } );

chunk += 2.;
Expand Down Expand Up @@ -1130,6 +1216,106 @@ SCENARIO( "InterpolationTable" ) {
CHECK( InterpolationType::LinearLinear == result.interpolants()[1] );
CHECK( true == std::holds_alternative< IntervalDomain< double > >( result.domain() ) );

chunk += nonzerothreshold;

CHECK( 6 == chunk.x().size() );
CHECK( 6 == chunk.y().size() );
CHECK( 3 == chunk.boundaries().size() );
CHECK( 3 == chunk.interpolants().size() );
CHECK_THAT( 1., WithinRel( chunk.x()[0] ) );
CHECK_THAT( 2., WithinRel( chunk.x()[1] ) );
CHECK_THAT( 2., WithinRel( chunk.x()[2] ) );
CHECK_THAT( 3., WithinRel( chunk.x()[3] ) );
CHECK_THAT( 3., WithinRel( chunk.x()[4] ) );
CHECK_THAT( 4., WithinRel( chunk.x()[5] ) );
CHECK_THAT( 4., WithinRel( chunk.y()[0] ) );
CHECK_THAT( 3., WithinRel( chunk.y()[1] ) );
CHECK_THAT( 4., WithinRel( chunk.y()[2] ) );
CHECK_THAT( 3., WithinRel( chunk.y()[3] ) );
CHECK_THAT( 4., WithinRel( chunk.y()[4] ) );
CHECK_THAT( 4., WithinRel( chunk.y()[5] ) );
CHECK( 1 == chunk.boundaries()[0] );
CHECK( 3 == chunk.boundaries()[1] );
CHECK( 5 == chunk.boundaries()[2] );
CHECK( InterpolationType::LinearLinear == chunk.interpolants()[0] );
CHECK( InterpolationType::LinearLinear == chunk.interpolants()[1] );
CHECK( InterpolationType::LinearLinear == chunk.interpolants()[2] );
CHECK( true == std::holds_alternative< IntervalDomain< double > >( chunk.domain() ) );

chunk -= nonzerothreshold;

CHECK( 5 == chunk.x().size() );
CHECK( 5 == chunk.y().size() );
CHECK( 2 == chunk.boundaries().size() );
CHECK( 2 == chunk.interpolants().size() );
CHECK_THAT( 1., WithinRel( chunk.x()[0] ) );
CHECK_THAT( 2., WithinRel( chunk.x()[1] ) );
CHECK_THAT( 2., WithinRel( chunk.x()[2] ) );
CHECK_THAT( 3., WithinRel( chunk.x()[3] ) );
CHECK_THAT( 4., WithinRel( chunk.x()[4] ) );
CHECK_THAT( 4., WithinRel( chunk.y()[0] ) );
CHECK_THAT( 3., WithinRel( chunk.y()[1] ) );
CHECK_THAT( 4., WithinRel( chunk.y()[2] ) );
CHECK_THAT( 3., WithinRel( chunk.y()[3] ) );
CHECK_THAT( 2., WithinRel( chunk.y()[4] ) );
CHECK( 1 == chunk.boundaries()[0] );
CHECK( 4 == chunk.boundaries()[1] );
CHECK( InterpolationType::LinearLinear == chunk.interpolants()[0] );
CHECK( InterpolationType::LinearLinear == chunk.interpolants()[1] );
CHECK( true == std::holds_alternative< IntervalDomain< double > >( chunk.domain() ) );

result = chunk + nonzerothreshold;

CHECK( 6 == result.x().size() );
CHECK( 6 == result.y().size() );
CHECK( 3 == result.boundaries().size() );
CHECK( 3 == result.interpolants().size() );
CHECK_THAT( 1., WithinRel( result.x()[0] ) );
CHECK_THAT( 2., WithinRel( result.x()[1] ) );
CHECK_THAT( 2., WithinRel( result.x()[2] ) );
CHECK_THAT( 3., WithinRel( result.x()[3] ) );
CHECK_THAT( 3., WithinRel( result.x()[4] ) );
CHECK_THAT( 4., WithinRel( result.x()[5] ) );
CHECK_THAT( 4., WithinRel( result.y()[0] ) );
CHECK_THAT( 3., WithinRel( result.y()[1] ) );
CHECK_THAT( 4., WithinRel( result.y()[2] ) );
CHECK_THAT( 3., WithinRel( result.y()[3] ) );
CHECK_THAT( 4., WithinRel( result.y()[4] ) );
CHECK_THAT( 4., WithinRel( result.y()[5] ) );
CHECK( 1 == result.boundaries()[0] );
CHECK( 3 == result.boundaries()[1] );
CHECK( 5 == result.boundaries()[2] );
CHECK( InterpolationType::LinearLinear == result.interpolants()[0] );
CHECK( InterpolationType::LinearLinear == result.interpolants()[1] );
CHECK( InterpolationType::LinearLinear == result.interpolants()[2] );
CHECK( true == std::holds_alternative< IntervalDomain< double > >( result.domain() ) );

result = chunk - nonzerothreshold;

CHECK( 6 == result.x().size() );
CHECK( 6 == result.y().size() );
CHECK( 3 == result.boundaries().size() );
CHECK( 3 == result.interpolants().size() );
CHECK_THAT( 1., WithinRel( result.x()[0] ) );
CHECK_THAT( 2., WithinRel( result.x()[1] ) );
CHECK_THAT( 2., WithinRel( result.x()[2] ) );
CHECK_THAT( 3., WithinRel( result.x()[3] ) );
CHECK_THAT( 3., WithinRel( result.x()[4] ) );
CHECK_THAT( 4., WithinRel( result.x()[5] ) );
CHECK_THAT( 4., WithinRel( result.y()[0] ) );
CHECK_THAT( 3., WithinRel( result.y()[1] ) );
CHECK_THAT( 4., WithinRel( result.y()[2] ) );
CHECK_THAT( 3., WithinRel( result.y()[3] ) );
CHECK_THAT( 2., WithinRel( result.y()[4] ) );
CHECK_THAT( 0., WithinRel( result.y()[5] ) );
CHECK( 1 == result.boundaries()[0] );
CHECK( 3 == result.boundaries()[1] );
CHECK( 5 == result.boundaries()[2] );
CHECK( InterpolationType::LinearLinear == result.interpolants()[0] );
CHECK( InterpolationType::LinearLinear == result.interpolants()[1] );
CHECK( InterpolationType::LinearLinear == result.interpolants()[2] );
CHECK( true == std::holds_alternative< IntervalDomain< double > >( result.domain() ) );

// this will add a second point at the lower end point
result = chunk + small;

Expand Down Expand Up @@ -1183,12 +1369,6 @@ SCENARIO( "InterpolationTable" ) {
CHECK( InterpolationType::LinearLinear == result.interpolants()[1] );
CHECK( InterpolationType::LinearLinear == result.interpolants()[2] );
CHECK( true == std::holds_alternative< IntervalDomain< double > >( result.domain() ) );

// the threshold table starts with a non-zero value
CHECK_THROWS( chunk += nonzerothreshold );
CHECK_THROWS( chunk -= nonzerothreshold );
CHECK_THROWS( result = chunk + nonzerothreshold );
CHECK_THROWS( result = chunk - nonzerothreshold );
} // THEN
} // WHEN
} // GIVEN
Expand Down Expand Up @@ -1569,6 +1749,17 @@ SCENARIO( "InterpolationTable" ) {
} // THEN
} // WHEN

WHEN( "the x grid has a jump at the beginning" ) {

std::vector< double > x = { 1., 1., 3., 4. };
std::vector< double > y = { 4., 3., 1., 4. };

THEN( "an exception is thrown" ) {

CHECK_THROWS( InterpolationTable< double >( std::move( x ), std::move( y ) ) );
} // THEN
} // WHEN

WHEN( "the x grid has a jump at the end" ) {

std::vector< double > x = { 1., 2., 4., 4. };
Expand Down
21 changes: 14 additions & 7 deletions src/scion/unionisation/unionise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,9 @@ namespace unionisation {
/**
* @brief Unionise two grids and preserve duplicate points that appear in each
*
* If the grids do not have the same end point, a duplicate point is inserted
* into the grid corresponding to the lowest end point (unless that is
* already a duplicate point).
*
* No special treatment is performed when the grids do not have the same
* starting point. We assume that the tabulated value at the higher starting
* point will be zero (we can enforce this behaviour when reading the data).
* If the grids do not have the same begin and/or end point, a duplicate point
* is inserted into the grid corresponding to the highest begining and/or
* lowest end point (unless those is already a duplicate point).
*
* @param first the first grid (assumed to be sorted)
* @param second the second grid (assumed to be sorted)
Expand All @@ -35,6 +31,17 @@ namespace unionisation {
grid.begin() );
grid.erase( end, grid.end() );

// special case: the begin points are not the same
if ( first.front() != second.front() ) {

X x = first.front() < second.front() ? second.front() : first.front();
auto iter = std::lower_bound( grid.begin(), grid.end(), x );
if ( *std::next( iter ) != x ) {

grid.insert( iter, x );
}
}

// special case: the end points are not the same
if ( first.back() != second.back() ) {

Expand Down

0 comments on commit 1e99fd2

Please sign in to comment.