Skip to content

Commit

Permalink
transpose points
Browse files Browse the repository at this point in the history
  • Loading branch information
mscroggs committed Jul 25, 2024
1 parent 8afdc7d commit 866c9a1
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 118 deletions.
206 changes: 104 additions & 102 deletions src/ciarlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,23 @@ impl<T: RlstScalar + MatrixInverse> CiarletElement<T> {
}

let new_pts = if continuity == Continuity::Discontinuous {
println!("A");
let mut new_pts: EntityPoints<T::Real> = [vec![], vec![], vec![], vec![]];
let mut all_pts = rlst_dynamic_array2![T::Real, [npts, tdim]];
let mut all_pts = rlst_dynamic_array2![T::Real, [tdim, npts]];
for (i, pts_i) in interpolation_points.iter().take(tdim).enumerate() {
for _pts in pts_i {
new_pts[i].push(rlst_dynamic_array2![T::Real, [0, tdim]]);
new_pts[i].push(rlst_dynamic_array2![T::Real, [tdim, 0]]);
}
}
let mut row = 0;
let mut col = 0;
for pts_i in interpolation_points.iter() {
for pts in pts_i {
let nrows = pts.shape()[0];
let ncols = pts.shape()[1];
all_pts
.view_mut()
.into_subview([row, 0], [nrows, tdim])
.into_subview([0, col], [tdim, ncols])
.fill_from(pts.view());
row += nrows;
col += ncols;
}
}
new_pts[tdim].push(all_pts);
Expand Down Expand Up @@ -146,8 +147,8 @@ impl<T: RlstScalar + MatrixInverse> CiarletElement<T> {
let mut dof = 0;
for d in 0..4 {
for (e, pts) in new_pts[d].iter().enumerate() {
if pts.shape()[0] > 0 {
let mut table = rlst_dynamic_array3!(T, [1, pdim, pts.shape()[0]]);
if pts.shape()[1] > 0 {
let mut table = rlst_dynamic_array3!(T, [1, pdim, pts.shape()[1]]);
tabulate_legendre_polynomials(
cell_type,
pts,
Expand Down Expand Up @@ -207,9 +208,9 @@ impl<T: RlstScalar + MatrixInverse> CiarletElement<T> {
let mut dof = 0;
for i in 0..4 {
for pts in &new_pts[i] {
let dofs = (dof..dof + pts.shape()[0]).collect::<Vec<_>>();
let dofs = (dof..dof + pts.shape()[1]).collect::<Vec<_>>();
entity_dofs[i].push(dofs);
dof += pts.shape()[0];
dof += pts.shape()[1];
}
}
let connectivity = reference_cell::connectivity(cell_type);
Expand Down Expand Up @@ -307,7 +308,7 @@ impl<T: RlstScalar + MatrixInverse> FiniteElement for CiarletElement<T> {
);

for d in 0..table.shape()[0] {
for p in 0..points.shape()[0] {
for p in 0..points.shape()[1] {
for j in 0..self.value_size {
for b in 0..self.dim {
// data[d, p, b, j] = inner(self.coefficients[b, j, :], table[d, :, p])
Expand Down Expand Up @@ -384,11 +385,11 @@ mod test {
let e = lagrange::create::<f64>(ReferenceCellType::Interval, 0, Continuity::Discontinuous);
assert_eq!(e.value_size(), 1);
let mut data = rlst_dynamic_array4!(f64, e.tabulate_array_shape(0, 4));
let mut points = rlst_dynamic_array2!(f64, [4, 1]);
let mut points = rlst_dynamic_array2!(f64, [1, 4]);
*points.get_mut([0, 0]).unwrap() = 0.0;
*points.get_mut([1, 0]).unwrap() = 0.2;
*points.get_mut([2, 0]).unwrap() = 0.4;
*points.get_mut([3, 0]).unwrap() = 1.0;
*points.get_mut([0, 1]).unwrap() = 0.2;
*points.get_mut([0, 2]).unwrap() = 0.4;
*points.get_mut([0, 3]).unwrap() = 1.0;
e.tabulate(&points, 0, &mut data);

for pt in 0..4 {
Expand All @@ -402,21 +403,21 @@ mod test {
let e = lagrange::create::<f64>(ReferenceCellType::Interval, 1, Continuity::Standard);
assert_eq!(e.value_size(), 1);
let mut data = rlst_dynamic_array4!(f64, e.tabulate_array_shape(0, 4));
let mut points = rlst_dynamic_array2!(f64, [4, 1]);
let mut points = rlst_dynamic_array2!(f64, [1, 4]);
*points.get_mut([0, 0]).unwrap() = 0.0;
*points.get_mut([1, 0]).unwrap() = 0.2;
*points.get_mut([2, 0]).unwrap() = 0.4;
*points.get_mut([3, 0]).unwrap() = 1.0;
*points.get_mut([0, 1]).unwrap() = 0.2;
*points.get_mut([0, 2]).unwrap() = 0.4;
*points.get_mut([0, 3]).unwrap() = 1.0;
e.tabulate(&points, 0, &mut data);

for pt in 0..4 {
assert_relative_eq!(
*data.get([0, pt, 0, 0]).unwrap(),
1.0 - *points.get([pt, 0]).unwrap()
1.0 - *points.get([0, pt]).unwrap()
);
assert_relative_eq!(
*data.get([0, pt, 1, 0]).unwrap(),
*points.get([pt, 0]).unwrap()
*points.get([0, pt]).unwrap()
);
}
check_dofs(e);
Expand All @@ -428,19 +429,19 @@ mod test {
assert_eq!(e.value_size(), 1);
let mut data = rlst_dynamic_array4!(f64, e.tabulate_array_shape(0, 6));

let mut points = rlst_dynamic_array2!(f64, [6, 2]);
let mut points = rlst_dynamic_array2!(f64, [2, 6]);
*points.get_mut([0, 0]).unwrap() = 0.0;
*points.get_mut([0, 1]).unwrap() = 0.0;
*points.get_mut([1, 0]).unwrap() = 1.0;
*points.get_mut([1, 0]).unwrap() = 0.0;
*points.get_mut([0, 1]).unwrap() = 1.0;
*points.get_mut([1, 1]).unwrap() = 0.0;
*points.get_mut([2, 0]).unwrap() = 0.0;
*points.get_mut([2, 1]).unwrap() = 1.0;
*points.get_mut([3, 0]).unwrap() = 0.5;
*points.get_mut([3, 1]).unwrap() = 0.0;
*points.get_mut([4, 0]).unwrap() = 0.0;
*points.get_mut([4, 1]).unwrap() = 0.5;
*points.get_mut([5, 0]).unwrap() = 0.5;
*points.get_mut([5, 1]).unwrap() = 0.5;
*points.get_mut([0, 2]).unwrap() = 0.0;
*points.get_mut([1, 2]).unwrap() = 1.0;
*points.get_mut([0, 3]).unwrap() = 0.5;
*points.get_mut([1, 3]).unwrap() = 0.0;
*points.get_mut([0, 4]).unwrap() = 0.0;
*points.get_mut([1, 4]).unwrap() = 0.5;
*points.get_mut([0, 5]).unwrap() = 0.5;
*points.get_mut([1, 5]).unwrap() = 0.5;

e.tabulate(&points, 0, &mut data);

Expand All @@ -455,33 +456,33 @@ mod test {
let e = lagrange::create::<f64>(ReferenceCellType::Triangle, 1, Continuity::Standard);
assert_eq!(e.value_size(), 1);
let mut data = rlst_dynamic_array4!(f64, e.tabulate_array_shape(0, 6));
let mut points = rlst_dynamic_array2!(f64, [6, 2]);
let mut points = rlst_dynamic_array2!(f64, [2, 6]);
*points.get_mut([0, 0]).unwrap() = 0.0;
*points.get_mut([0, 1]).unwrap() = 0.0;
*points.get_mut([1, 0]).unwrap() = 1.0;
*points.get_mut([1, 0]).unwrap() = 0.0;
*points.get_mut([0, 1]).unwrap() = 1.0;
*points.get_mut([1, 1]).unwrap() = 0.0;
*points.get_mut([2, 0]).unwrap() = 0.0;
*points.get_mut([2, 1]).unwrap() = 1.0;
*points.get_mut([3, 0]).unwrap() = 0.5;
*points.get_mut([3, 1]).unwrap() = 0.0;
*points.get_mut([4, 0]).unwrap() = 0.0;
*points.get_mut([4, 1]).unwrap() = 0.5;
*points.get_mut([5, 0]).unwrap() = 0.5;
*points.get_mut([5, 1]).unwrap() = 0.5;
*points.get_mut([0, 2]).unwrap() = 0.0;
*points.get_mut([1, 2]).unwrap() = 1.0;
*points.get_mut([0, 3]).unwrap() = 0.5;
*points.get_mut([1, 3]).unwrap() = 0.0;
*points.get_mut([0, 4]).unwrap() = 0.0;
*points.get_mut([1, 4]).unwrap() = 0.5;
*points.get_mut([0, 5]).unwrap() = 0.5;
*points.get_mut([1, 5]).unwrap() = 0.5;
e.tabulate(&points, 0, &mut data);

for pt in 0..6 {
assert_relative_eq!(
*data.get([0, pt, 0, 0]).unwrap(),
1.0 - *points.get([pt, 0]).unwrap() - *points.get([pt, 1]).unwrap()
1.0 - *points.get([0, pt]).unwrap() - *points.get([1, pt]).unwrap()
);
assert_relative_eq!(
*data.get([0, pt, 1, 0]).unwrap(),
*points.get([pt, 0]).unwrap()
*points.get([0, pt]).unwrap()
);
assert_relative_eq!(
*data.get([0, pt, 2, 0]).unwrap(),
*points.get([pt, 1]).unwrap()
*points.get([1, pt]).unwrap()
);
}
check_dofs(e);
Expand Down Expand Up @@ -551,19 +552,19 @@ mod test {
);
assert_eq!(e.value_size(), 1);
let mut data = rlst_dynamic_array4!(f64, e.tabulate_array_shape(0, 6));
let mut points = rlst_dynamic_array2!(f64, [6, 2]);
let mut points = rlst_dynamic_array2!(f64, [2, 6]);
*points.get_mut([0, 0]).unwrap() = 0.0;
*points.get_mut([0, 1]).unwrap() = 0.0;
*points.get_mut([1, 0]).unwrap() = 1.0;
*points.get_mut([1, 0]).unwrap() = 0.0;
*points.get_mut([0, 1]).unwrap() = 1.0;
*points.get_mut([1, 1]).unwrap() = 0.0;
*points.get_mut([2, 0]).unwrap() = 0.0;
*points.get_mut([2, 1]).unwrap() = 1.0;
*points.get_mut([3, 0]).unwrap() = 0.5;
*points.get_mut([3, 1]).unwrap() = 0.0;
*points.get_mut([4, 0]).unwrap() = 0.0;
*points.get_mut([4, 1]).unwrap() = 0.5;
*points.get_mut([5, 0]).unwrap() = 0.5;
*points.get_mut([5, 1]).unwrap() = 0.5;
*points.get_mut([0, 2]).unwrap() = 0.0;
*points.get_mut([1, 2]).unwrap() = 1.0;
*points.get_mut([0, 3]).unwrap() = 0.5;
*points.get_mut([1, 3]).unwrap() = 0.0;
*points.get_mut([0, 4]).unwrap() = 0.0;
*points.get_mut([1, 4]).unwrap() = 0.5;
*points.get_mut([0, 5]).unwrap() = 0.5;
*points.get_mut([1, 5]).unwrap() = 0.5;
e.tabulate(&points, 0, &mut data);

for pt in 0..6 {
Expand All @@ -577,38 +578,38 @@ mod test {
let e = lagrange::create::<f64>(ReferenceCellType::Quadrilateral, 1, Continuity::Standard);
assert_eq!(e.value_size(), 1);
let mut data = rlst_dynamic_array4!(f64, e.tabulate_array_shape(0, 6));
let mut points = rlst_dynamic_array2!(f64, [6, 2]);
let mut points = rlst_dynamic_array2!(f64, [2, 6]);
*points.get_mut([0, 0]).unwrap() = 0.0;
*points.get_mut([0, 1]).unwrap() = 0.0;
*points.get_mut([1, 0]).unwrap() = 1.0;
*points.get_mut([1, 0]).unwrap() = 0.0;
*points.get_mut([0, 1]).unwrap() = 1.0;
*points.get_mut([1, 1]).unwrap() = 0.0;
*points.get_mut([2, 0]).unwrap() = 0.0;
*points.get_mut([2, 1]).unwrap() = 1.0;
*points.get_mut([3, 0]).unwrap() = 1.0;
*points.get_mut([3, 1]).unwrap() = 1.0;
*points.get_mut([4, 0]).unwrap() = 0.25;
*points.get_mut([4, 1]).unwrap() = 0.5;
*points.get_mut([5, 0]).unwrap() = 0.3;
*points.get_mut([5, 1]).unwrap() = 0.2;
*points.get_mut([0, 2]).unwrap() = 0.0;
*points.get_mut([1, 2]).unwrap() = 1.0;
*points.get_mut([0, 3]).unwrap() = 1.0;
*points.get_mut([1, 3]).unwrap() = 1.0;
*points.get_mut([0, 4]).unwrap() = 0.25;
*points.get_mut([1, 4]).unwrap() = 0.5;
*points.get_mut([0, 5]).unwrap() = 0.3;
*points.get_mut([1, 5]).unwrap() = 0.2;

e.tabulate(&points, 0, &mut data);

for pt in 0..6 {
assert_relative_eq!(
*data.get([0, pt, 0, 0]).unwrap(),
(1.0 - *points.get([pt, 0]).unwrap()) * (1.0 - *points.get([pt, 1]).unwrap())
(1.0 - *points.get([0, pt]).unwrap()) * (1.0 - *points.get([1, pt]).unwrap())
);
assert_relative_eq!(
*data.get([0, pt, 1, 0]).unwrap(),
*points.get([pt, 0]).unwrap() * (1.0 - *points.get([pt, 1]).unwrap())
*points.get([0, pt]).unwrap() * (1.0 - *points.get([1, pt]).unwrap())
);
assert_relative_eq!(
*data.get([0, pt, 2, 0]).unwrap(),
(1.0 - *points.get([pt, 0]).unwrap()) * *points.get([pt, 1]).unwrap()
(1.0 - *points.get([0, pt]).unwrap()) * *points.get([1, pt]).unwrap()
);
assert_relative_eq!(
*data.get([0, pt, 3, 0]).unwrap(),
*points.get([pt, 0]).unwrap() * *points.get([pt, 1]).unwrap()
*points.get([0, pt]).unwrap() * *points.get([1, pt]).unwrap()
);
}
check_dofs(e);
Expand All @@ -619,24 +620,25 @@ mod test {
let e = lagrange::create::<f64>(ReferenceCellType::Quadrilateral, 2, Continuity::Standard);
assert_eq!(e.value_size(), 1);
let mut data = rlst_dynamic_array4!(f64, e.tabulate_array_shape(0, 6));
let mut points = rlst_dynamic_array2!(f64, [6, 2]);
let mut points = rlst_dynamic_array2!(f64, [2, 6]);
*points.get_mut([0, 0]).unwrap() = 0.0;
*points.get_mut([0, 1]).unwrap() = 0.0;
*points.get_mut([1, 0]).unwrap() = 1.0;
*points.get_mut([1, 0]).unwrap() = 0.0;
*points.get_mut([0, 1]).unwrap() = 1.0;
*points.get_mut([1, 1]).unwrap() = 0.0;
*points.get_mut([2, 0]).unwrap() = 0.0;
*points.get_mut([2, 1]).unwrap() = 1.0;
*points.get_mut([3, 0]).unwrap() = 1.0;
*points.get_mut([3, 1]).unwrap() = 1.0;
*points.get_mut([4, 0]).unwrap() = 0.25;
*points.get_mut([4, 1]).unwrap() = 0.5;
*points.get_mut([5, 0]).unwrap() = 0.3;
*points.get_mut([5, 1]).unwrap() = 0.2;
*points.get_mut([0, 2]).unwrap() = 0.0;
*points.get_mut([1, 2]).unwrap() = 1.0;
*points.get_mut([0, 3]).unwrap() = 1.0;
*points.get_mut([1, 3]).unwrap() = 1.0;
*points.get_mut([0, 4]).unwrap() = 0.25;
*points.get_mut([1, 4]).unwrap() = 0.5;
*points.get_mut([0, 5]).unwrap() = 0.3;
*points.get_mut([1, 5]).unwrap() = 0.2;

e.tabulate(&points, 0, &mut data);

for pt in 0..6 {
let x = *points.get([pt, 0]).unwrap();
let y = *points.get([pt, 1]).unwrap();
let x = *points.get([0, pt]).unwrap();
let y = *points.get([1, pt]).unwrap();
assert_relative_eq!(
*data.get([0, pt, 0, 0]).unwrap(),
(1.0 - x) * (1.0 - 2.0 * x) * (1.0 - y) * (1.0 - 2.0 * y),
Expand Down Expand Up @@ -691,45 +693,45 @@ mod test {
let e = raviart_thomas::create(ReferenceCellType::Triangle, 1, Continuity::Standard);
assert_eq!(e.value_size(), 2);
let mut data = rlst_dynamic_array4!(f64, e.tabulate_array_shape(0, 6));
let mut points = rlst_dynamic_array2!(f64, [6, 2]);
let mut points = rlst_dynamic_array2!(f64, [2, 6]);
*points.get_mut([0, 0]).unwrap() = 0.0;
*points.get_mut([0, 1]).unwrap() = 0.0;
*points.get_mut([1, 0]).unwrap() = 1.0;
*points.get_mut([1, 0]).unwrap() = 0.0;
*points.get_mut([0, 1]).unwrap() = 1.0;
*points.get_mut([1, 1]).unwrap() = 0.0;
*points.get_mut([2, 0]).unwrap() = 0.0;
*points.get_mut([2, 1]).unwrap() = 1.0;
*points.get_mut([3, 0]).unwrap() = 0.5;
*points.get_mut([3, 1]).unwrap() = 0.0;
*points.get_mut([4, 0]).unwrap() = 0.0;
*points.get_mut([4, 1]).unwrap() = 0.5;
*points.get_mut([5, 0]).unwrap() = 0.5;
*points.get_mut([5, 1]).unwrap() = 0.5;
*points.get_mut([0, 2]).unwrap() = 0.0;
*points.get_mut([1, 2]).unwrap() = 1.0;
*points.get_mut([0, 3]).unwrap() = 0.5;
*points.get_mut([1, 3]).unwrap() = 0.0;
*points.get_mut([0, 4]).unwrap() = 0.0;
*points.get_mut([1, 4]).unwrap() = 0.5;
*points.get_mut([0, 5]).unwrap() = 0.5;
*points.get_mut([1, 5]).unwrap() = 0.5;
e.tabulate(&points, 0, &mut data);

for pt in 0..6 {
assert_relative_eq!(
*data.get([0, pt, 0, 0]).unwrap(),
-*points.get([pt, 0]).unwrap()
-*points.get([0, pt]).unwrap()
);
assert_relative_eq!(
*data.get([0, pt, 0, 1]).unwrap(),
-*points.get([pt, 1]).unwrap()
-*points.get([1, pt]).unwrap()
);
assert_relative_eq!(
*data.get([0, pt, 1, 0]).unwrap(),
*points.get([pt, 0]).unwrap() - 1.0
*points.get([0, pt]).unwrap() - 1.0
);
assert_relative_eq!(
*data.get([0, pt, 1, 1]).unwrap(),
*points.get([pt, 1]).unwrap()
*points.get([1, pt]).unwrap()
);
assert_relative_eq!(
*data.get([0, pt, 2, 0]).unwrap(),
-*points.get([pt, 0]).unwrap()
-*points.get([0, pt]).unwrap()
);
assert_relative_eq!(
*data.get([0, pt, 2, 1]).unwrap(),
1.0 - *points.get([pt, 1]).unwrap()
1.0 - *points.get([1, pt]).unwrap()
);
}
check_dofs(e);
Expand Down
Loading

0 comments on commit 866c9a1

Please sign in to comment.