Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3.3.0 #189

Merged
merged 74 commits into from
Feb 12, 2025
Merged

3.3.0 #189

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
a6e0bdb
[geom] Fix Cuboid normals
Dec 4, 2024
65395c6
[geom] Fix Mesh.volume
Dec 4, 2024
079638c
[vis] Matplotlib Cuboid fix
Dec 4, 2024
1b50a4b
[geom] Fix Cylinder.bounding_half_extent for non-matching dimensions
Dec 4, 2024
e661387
[Φ] Add *min, *max, *prod convenience imports
Dec 4, 2024
97f0c1d
[geom] Support periodic boundaries with inverse order in Mesh
holl- Dec 5, 2024
e19816a
[geom] Fix multi-periodic Mesh
holl- Dec 6, 2024
9986384
[geom] 3D Mesh elements
holl- Dec 6, 2024
6d537cb
[Φ] Add meshgrid, *pack to convenience imports
holl- Dec 9, 2024
ae66051
[geom] Refactor Cylinder, Mesh to use new PhiML dataclass support
holl- Dec 9, 2024
15fdb8d
[geom] Implement Cylinder.face_centers
Dec 10, 2024
930139b
[vis] Fix plotting higher-order connected points
Dec 11, 2024
0d6fef8
[vis] Fix Plotly quad rendering
Dec 12, 2024
435d4b0
[geom] Add MeshBuilder (experimental)
Dec 13, 2024
15769d3
[geom] Add axis_angle_from_directions()
Dec 13, 2024
5d442a2
[geom] Add axes parameter to mesh_from_numpy()
Dec 13, 2024
3f1f840
[geom] Add Geometry.bounding_sphere()
Dec 13, 2024
f61b50d
[geom] Fix Geometry.__getattr__ to work with copy
Dec 13, 2024
9a3c419
[geom] Fix Cylinder.face_centers
Dec 13, 2024
b37d0bc
[geom] Refactor: move Tensor functions to _functions.py, Box.rotation…
Dec 13, 2024
2b28a1f
[geom] Fix Cylinder.face_centers
holl- Dec 14, 2024
1d256ff
Fix Field.__getitem__ for non-grids
holl- Dec 15, 2024
4eb7e65
[geom] Fix backend in Mesh.build_faces()
holl- Dec 15, 2024
df3fe6f
[demos] Update FVM_Cylinder_GMsh.ipynb
holl- Dec 15, 2024
180a4f8
[geom] Fix surface_mesh() params check
Dec 16, 2024
2216eff
[geom] Fix Geometry.__getattr__
Dec 16, 2024
d3e1e14
[geom] Add is_normalized to closest_normal_vector()
Dec 16, 2024
e960f20
[field] Fix slicing
Dec 17, 2024
afd9ac1
[geom,io] Update to PhiML 1.12
holl- Dec 22, 2024
9cce527
[geom] Remove diagonal from Mesh.element_connectivity for surface meshes
Jan 3, 2025
2070165
[geom] Use trimesh to load STL files
Jan 3, 2025
ebf4143
[geom] Implement Sphere.face_areas
Jan 4, 2025
4647f1e
[geom] Fix _mesh.py imports
Jan 4, 2025
65d0ec2
[geom] Fix Mesh.pad_boundary for periodic meshes
holl- Dec 30, 2024
80983dd
[field] Fix mean() for unstructured meshes
holl- Jan 4, 2025
7943d31
[physics] Detect rank_deficiency for meshes in make_incompressible()
holl- Jan 4, 2025
d7d9939
[Φ] Add *cat, p2d, primal reductions to convenience imports
holl- Jan 5, 2025
8491c8c
[field] Fix Field.bounds, make sliceable
Jan 8, 2025
7fdc723
[geom] Fix Mesh.pad_boundary for periodic meshes
holl- Jan 10, 2025
78b8f71
[geom] Fix Mesh connectivity for purely periodic meshes
holl- Jan 11, 2025
5326ff0
[field] Add Field.boundary_names
holl- Jan 12, 2025
e9d9907
[field] Avoid referencing non-existing Field dims
holl- Jan 12, 2025
3a5343a
[geom] Fix rotate()
holl- Jan 12, 2025
f9e3166
[tests] Remove plot from FLIP test
holl- Jan 12, 2025
137aee3
[geom] Fix Mesh.face_shape sizes
holl- Jan 12, 2025
9bb27e9
[field] Corner curl for 3D staggered grids
holl- Jan 21, 2025
cb78041
[geom] Set default surface_mesh() resolution to 128
Jan 13, 2025
1b2aec4
[geom] Compatibility with Python 3.8
Jan 14, 2025
43a27c7
[geom] Fix Mesh.boundary_faces
Jan 22, 2025
8bac536
[geom] Remove type hint from union()
Jan 22, 2025
3543588
[geom] Disable length() epsilon by default
Jan 22, 2025
b0795f1
[vis] Support Plotly 3D mesh union with colormap
Jan 24, 2025
a346e3b
[geom] Batched MeshBuilder
Jan 24, 2025
70f1052
[geom] Fix GeometryStack
holl- Jan 27, 2025
bfa19eb
[geom] Fix Mesh creation
holl- Jan 27, 2025
b0e93b1
[field] Allow stacking fields with incompatible geometries
holl- Jan 27, 2025
2fef175
[field] Support wide stencil laplace() for meshes
holl- Feb 2, 2025
ec3c45b
[geom] Cache Mesh.volume
Jan 26, 2025
25e7616
[vis] Fix plot batched meshes with user-defined color
Jan 26, 2025
014dbf8
[geom] Add squared_length()
Jan 28, 2025
991df52
[geom] Allow rot=None in rotate_vector()
Feb 3, 2025
64f38a1
[geom] Fix Geometry.bounding_sphere()
Feb 3, 2025
1e3d424
[geom] Track origin face in MeshBuilder
Feb 3, 2025
9f053ea
[geom] Add MeshBuilder.build_displaced_mesh()
Feb 5, 2025
cd7100b
[vis] Fix Plotly colors and item names
Feb 6, 2025
77f39b7
[physics] Simplify make_incompressible() for higher-order
Feb 6, 2025
bbd632d
[vis] Fix test__vis_base.py
Feb 6, 2025
aae7fa1
[tests,physics] Slightly loosen make_incompressible tolerance due to …
holl- Feb 6, 2025
7e75aa8
[vis] Fix animations
holl- Feb 6, 2025
f8f9de4
[vis] Fix Matplotlib VectorCloud2D plot
holl- Feb 6, 2025
cb5220a
[tests] Don't use rank-deficient solver in fluid test
holl- Feb 8, 2025
372ad31
[field] Fix resample() to Geometry
holl- Feb 8, 2025
bb5e15a
[geom] Fix Geometry comparison
holl- Feb 8, 2025
d55ec61
[Φ] Bump version to 3.3.0
Feb 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions examples/mesh/FVM_Cylinder_GMsh.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,26 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %pip install --quiet phiflow\n",
"# %pip install --quiet phiflow meshio\n",
"from phi.torch.flow import *\n",
"# from phi.flow import * # If JAX is not installed. You can use phi.torch or phi.tf as well.\n",
"from tqdm.notebook import trange"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's download the example mesh file and visualize it!"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -59,10 +66,19 @@
}
],
"source": [
"!wget https://raw.githubusercontent.com/tum-pbs/PhiFlow/master/examples/mesh/cylinder.msh -O cylinder.msh\n",
"\n",
"mesh = geom.load_gmsh('cylinder.msh', ('y-', 'x+', 'y+', 'x-', 'cyl+', 'cyl-'))\n",
"plot(Box(x=6, y=6), mesh, overlay='args', size=(4, 3), title='cylinder.msh')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we can define the fluid dynamics. The momentum equation uses diffusion and advection while incompressibility is maintained by solving a separate linear system via `make_incompressible`."
]
},
{
"cell_type": "code",
"execution_count": 3,
Expand All @@ -82,6 +98,13 @@
" return v"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's set the boundary and initial conditions and run the simulation!"
]
},
{
"cell_type": "code",
"execution_count": 11,
Expand Down
2 changes: 1 addition & 1 deletion phi/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.2.0
3.3.0
17 changes: 10 additions & 7 deletions phi/field/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from phi.math import Shape, Tensor, channel, non_batch, expand, instance, spatial, wrap, dual, non_dual
from phi.math.extrapolation import Extrapolation
from phi.math.magic import BoundDim, slicing_dict
from phiml.dataclasses import sliceable
from phiml.math import batch, Solve, DimFilter, unstack, concat_shapes, pack_dims, shape
from phiml.math.extrapolation import domain_slice

Expand Down Expand Up @@ -45,6 +46,7 @@ def __call__(cls,
return result


@sliceable
@dataclass(frozen=True)
class Field(metaclass=_FieldType):
"""
Expand Down Expand Up @@ -174,7 +176,7 @@ def numpy(self, order: DimFilter = None):
assert order is not None, f"order must be specified for non-uniform Field values"
order = self.values.shape.only(order, reorder=True)
stack_dims = order.non_uniform_shape
inner_order = order.without(stack_dims)
inner_order = order.without(stack_dims).names
return [v.numpy(inner_order) for v in unstack(self.values, stack_dims)]

def uniform_values(self):
Expand Down Expand Up @@ -206,7 +208,7 @@ def shape(self) -> Shape:
if self.is_grid and '~vector' in self.values.shape:
return batch(self.geometry) & self.resolution & non_dual(self.values).without(self.resolution) & self.geometry.shape['vector']
set_shape = self.geometry.sets[self.sampled_at]
return batch(self.geometry) & (channel(self.geometry) - 'vector') & set_shape & self.values
return batch(self.geometry) & (channel(self.geometry) - 'vector') & set_shape & self.values.shape

@property
def resolution(self):
Expand All @@ -232,7 +234,7 @@ def bounds(self) -> BaseBox:

Fields whose spatial rank is determined only during sampling return an empty `Box`.
"""
if isinstance(self.geometry.bounds, BaseBox):
if hasattr(self.geometry, 'bounds') and isinstance(self.geometry.bounds, BaseBox):
return self.geometry.bounds
extent = self.geometry.bounding_half_extent().vector.as_dual('_extent')
points = self.geometry.center + extent
Expand All @@ -242,6 +244,10 @@ def bounds(self) -> BaseBox:

box = bounds

@property
def boundary_names(self):
return tuple(self.geometry.boundary_faces)

@property
def is_grid(self):
"""A Field represents grid data if its `geometry` is a `phi.geom.UniformGrid` instance."""
Expand Down Expand Up @@ -662,7 +668,7 @@ def __getitem__(self, item) -> 'Field':
item = slicing_dict(self, item)
if not item:
return self
boundary = domain_slice(self.boundary, item, self.resolution)
boundary = domain_slice(self.boundary, item, domain_dims=self.boundary_names)
item_without_vec = {dim: selection for dim, selection in item.items() if dim != 'vector'}
geometry = self.geometry[item_without_vec]
if self.is_staggered and 'vector' in item and '~vector' in self.geometry.face_shape:
Expand All @@ -682,9 +688,6 @@ def __getitem__(self, item) -> 'Field':
values = self.values[item]
return Field(geometry, values, boundary)

def __getattr__(self, name: str) -> BoundDim:
return BoundDim(self, name)

def dimension(self, name: str):
"""
Returns a reference to one of the dimensions of this field.
Expand Down
6 changes: 4 additions & 2 deletions phi/field/_field_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from phi import geom, math
from phiml.math._shape import from_dict
from ._field import Field
from ._grid import unstack_staggered_tensor, CenteredGrid, StaggeredGrid
from ._field_math import stack
Expand Down Expand Up @@ -55,7 +56,7 @@ def write_single_field(field: Field, file: str):
field_type = 'StaggeredGrid' if field.is_staggered else 'CenteredGrid'
np.savez_compressed(file,
dim_names=dim_names,
dim_types=field.shape.types,
dim_types=field.shape.dim_types,
dim_item_names=np.asarray(field.shape.item_names, dtype=object),
field_type=field_type,
lower=lower,
Expand Down Expand Up @@ -103,7 +104,8 @@ def read_single_field(file: str, convert_to_backend=True) -> Field:
raise NotImplementedError(f"{ftype} not implemented")
data_arr = stored['data']
dim_item_names = stored.get('dim_item_names', (None,) * len(data_arr.shape))
data = tensor(data_arr, Shape(data_arr.shape, tuple(stored['dim_names']), tuple(stored['dim_types']), tuple(dim_item_names)), convert=convert_to_backend)
shape_spec = {'names': tuple(stored['dim_names']), 'sizes': data_arr.shape, 'types': tuple(stored['dim_types']), 'item_names': tuple(dim_item_names)}
data = tensor(data_arr, from_dict(shape_spec), convert=convert_to_backend)
bounds_item_names = stored.get('bounds_item_names', None)
if bounds_item_names is None or bounds_item_names.shape == (): # None or empty array
bounds_item_names = spatial(data).names
Expand Down
77 changes: 54 additions & 23 deletions phi/field/_field_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from typing import Callable, List, Tuple, Optional, Union, Sequence

import numpy as np
from phiml.math import Tensor, spatial, instance, tensor, channel, batch, Shape, unstack, solve_linear, jit_compile_linear, \
shape, Solve, extrapolation, dual, wrap, rename_dims, factorial, concat, zeros, ones
from phiml.math import Tensor, spatial, instance, tensor, channel, batch, Shape, unstack, solve_linear, \
jit_compile_linear, \
shape, Solve, extrapolation, dual, wrap, rename_dims, factorial, concat, zeros, ones, neighbor_mean
from phi import geom
from phi import math
from phi.geom import Box, Geometry, UniformGrid
Expand Down Expand Up @@ -50,7 +51,8 @@ def laplace(u: Field,
implicitness: int = None,
weights: Union[Tensor, Field] = None,
upwind: Field = None,
correct_skew=True) -> Field:
correct_skew=True,
wide_stencil=False) -> Field:
"""
Spatial Laplace operator for scalar grid.

Expand All @@ -77,7 +79,6 @@ def laplace(u: Field,
Returns:
laplacian field as `CenteredGrid`
"""

if implicitness is None:
implicitness = 0 if implicit is None else 2
elif implicitness != 0:
Expand All @@ -94,6 +95,11 @@ def laplace(u: Field,
raise NotImplementedError(f"laplace on meshes is not yet supported with vector-valued weights")
neighbor_val = u.mesh.pad_boundary(u.values, mode=u.boundary)
nb_distances = u.mesh.neighbor_distances
if wide_stencil:
assert weights is None
grad_p = spatial_gradient(u, order=order, scheme='green-gauss', upwind=upwind)
div_grad_p = grad_p.divergence(order=order, upwind=upwind)
return div_grad_p
connecting_grad = (u.mesh.connectivity * neighbor_val - u.values) / nb_distances # (T_N - T_P) / d_PN
if correct_skew and gradient is not None: # skewness correction
assert dual(gradient).names == ('~vector',), f"gradient must contain one dual dim '~vector' listing the gradient components but got {gradient.shape}"
Expand All @@ -109,41 +115,33 @@ def laplace(u: Field,
laplace_values = u.mesh.integrate_surface(grad) / u.mesh.volume # 1/V ∑_f ∇T ν A
result = weights * laplace_values if weights is not None else laplace_values
return Field(u.mesh, result, u.boundary - u.boundary)

# --- Grid ---

laplace_ext = u.extrapolation.spatial_gradient().spatial_gradient()
laplace_dims = u.shape.only(axes).names

if u.vector.exists and (u.is_centered or order > 2):
if 'vector' in u.shape and (u.is_centered or order > 2):
fields = [f for f in u.vector]
else:
fields = [u]

result = []
for f in fields:
if order == 2:
result.append(math.map_d2c(math.laplace)(f.values, dx=f.dx, padding=f.extrapolation, dims=axes, weights=weights, padding_kwargs={'bounds': f.bounds})) # uses ghost cells
else:
result_components = [perform_finite_difference_operation(f.values, dim, 2, f.dx.vector[dim], f.extrapolation,
laplace_ext, 'center', order, implicit,
implicitness) for dim in laplace_dims]
result_components = [perform_finite_difference_operation(f.values, dim, 2, f.dx.vector[dim], f.extrapolation, laplace_ext, 'center', order, implicit, implicitness) for dim in laplace_dims]
if weights is not None:
if channel(weights):
result_components = [c * weights[ax] for c, ax in zip(result_components, axes_names)]
else:
result_components = [c * weights for c in result_components]

result.append(sum(result_components))

if u.vector.exists and (u.is_centered or order > 2):
if 'vector' in u.shape and (u.is_centered or order > 2):
if u.is_staggered:
result = math.stack(result, dual(vector=u.vector.item_names))
else:
result = math.stack(result, channel(vector=u.vector.item_names))
else:
result = result[0]

return u.with_values(result).with_extrapolation(laplace_ext)


Expand Down Expand Up @@ -205,17 +203,14 @@ def spatial_gradient(field: Field,
return least_squares_gradient(field, stack_dim=stack_dim, boundary=boundary)
raise NotImplementedError(scheme)


if field.vector.exists:
if 'vector' in field.shape:
assert stack_dim.name != 'vector', "`stack_dim=vector` is inadmissible if the input is a vector grid"
if field == StaggeredGrid:
assert at == 'faces', "for a `StaggeredGrid` input only `type == StaggeredGrid` is possible"

if at == 'faces':
assert stack_dim.name == 'vector', f"spatial_gradient with type=StaggeredGrid requires stack_dim.name == 'vector' but got '{stack_dim.name}'"



if gradient_extrapolation is None:
gradient_extrapolation = field.extrapolation.spatial_gradient()

Expand Down Expand Up @@ -657,8 +652,8 @@ def curl(field: Field, at='corner'):
if field.is_grid and field.is_staggered and field.spatial_rank == 2 and at == 'corner':
x, y = field.vector.item_names
values = field.with_boundary(None).values
vx = math.pad(values.vector.dual[x], {y: (1, 1)}, field.extrapolation[{'vector': y}])
vy = math.pad(values.vector.dual[y], {x: (1, 1)}, field.extrapolation[{'vector': x}])
vx = math.pad(values.vector.dual[x], {y: (1, 1)}, field.boundary[{'vector': y}])
vy = math.pad(values.vector.dual[y], {x: (1, 1)}, field.boundary[{'vector': x}])
vy_dx = math.spatial_gradient(vy, dims=x, dx=field.dx[x], padding=None, stack_dim=None, difference='forward')
vx_dy = math.spatial_gradient(vx, dims=y, dx=field.dx[y], padding=None, stack_dim=None, difference='forward')
curl_val = vy_dx - vx_dy
Expand Down Expand Up @@ -699,6 +694,36 @@ def curl(field: Field, at='corner'):
# vy_dx = math.spatial_gradient(y_padded, field.dx, 'forward', None, dims='x', stack_dim=None)
# result = vy_dx - vx_dy
# return CenteredGrid(result, field.extrapolation.spatial_gradient(), field.bounds)
elif field.is_grid and field.is_staggered and field.spatial_rank == 3 and at == 'corner':
x, y, z = field.vector.item_names
values = field.with_boundary(None).values
vx = math.pad(values.vector.dual[x], {y: (1, 1), z: (1, 1)}, field.boundary[{'vector': y}])
vy = math.pad(values.vector.dual[y], {x: (1, 1), z: (1, 1)}, field.boundary[{'vector': x}])
vz = math.pad(values.vector.dual[z], {x: (1, 1), y: (1, 1)}, field.boundary[{'vector': z}])
vx_dy = neighbor_mean(math.spatial_gradient(vx, dims=y, dx=field.dx[y], padding=None, stack_dim=None, difference='forward'), z)
vx_dz = neighbor_mean(math.spatial_gradient(vx, dims=z, dx=field.dx[z], padding=None, stack_dim=None, difference='forward'), y)
vy_dx = neighbor_mean(math.spatial_gradient(vy, dims=x, dx=field.dx[x], padding=None, stack_dim=None, difference='forward'), z)
vy_dz = neighbor_mean(math.spatial_gradient(vy, dims=z, dx=field.dx[z], padding=None, stack_dim=None, difference='forward'), x)
vz_dx = neighbor_mean(math.spatial_gradient(vz, dims=x, dx=field.dx[x], padding=None, stack_dim=None, difference='forward'), y)
vz_dy = neighbor_mean(math.spatial_gradient(vz, dims=y, dx=field.dx[y], padding=None, stack_dim=None, difference='forward'), x)
curl_val = math.stack([vz_dy-vy_dz, vx_dz-vz_dx, vy_dx-vx_dy], field.shape['vector'])
corners = UniformGrid(field.resolution + 1, Box(field.bounds.lower - field.dx / 2, field.bounds.upper + field.dx / 2))
return Field(corners, curl_val, field.boundary.spatial_gradient())
elif field.is_grid and field.is_centered and field.spatial_rank == 3 and at == 'corner':
raise NotImplementedError
x, y, z = field.vector.item_names
values = pad(field, 1).values
# ToDo 8 diag offset vectors, account for cell stretching
# Then sum (offset x v) / |offset|^2 ??
diag_basis = wrap([(1, 1), (1, -1)], channel(diag='pos,neg'), dual(vector=[x, y]))
diag_comp = diag_basis @ values
ll = diag_comp[{x: slice(-1), y: slice(-1), 'diag': 'neg'}]
ul = diag_comp[{x: slice(-1), y: slice(1, None), 'diag': 'pos'}]
lr = diag_comp[{x: slice(1, None), y: slice(-1), 'diag': 'pos'}]
ur = diag_comp[{x: slice(1, None), y: slice(1, None), 'diag': 'neg'}]
curl_val = ll - ul + lr - ur
corners = UniformGrid(field.resolution + 1, Box(field.bounds.lower - field.dx / 2, field.bounds.upper + field.dx / 2))
return Field(corners, curl_val, field.boundary.spatial_gradient())
raise NotImplementedError("Only 2D curl at corner currently supported")


Expand Down Expand Up @@ -762,7 +787,10 @@ def mean(field: Field, dim=lambda s: s.non_channel.non_batch) -> Tensor:
Returns:
`phi.Tensor`
"""
result = math.mean(field.values, dim=dim)
if field.is_grid:
result = math.mean(field.values, dim=dim)
else:
result = math.mean(field.values, dim=dim, weight=field.geometry.volume)
if (instance(field.geometry) & spatial(field.geometry)) in result.shape:
return field.with_values(result)
return result
Expand Down Expand Up @@ -948,7 +976,10 @@ def stack(fields: Sequence[Field], dim: Shape, dim_bounds: Box = None):
return fields[0].with_values(values).with_boundary(boundary)
else:
values = math.stack([f.values for f in fields], dim)
geometry = fields[0].geometry if all(f.geometry == fields[0].geometry for f in fields) else math.stack([f.geometry for f in fields], dim)
geometry = fields[0].geometry if all(f.geometry == fields[0].geometry for f in fields) else math.stack([f.geometry for f in fields], dim, layout_non_matching=True)
if isinstance(geometry, Tensor):
from phi.geom._geom_ops import GeometryStack
geometry = GeometryStack(geometry)
return Field(geometry, values, boundary)


Expand Down
4 changes: 2 additions & 2 deletions phi/field/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def StaggeredGrid(values: Any = 0.,
values = sample(values, elements, at='face', boundary=extrapolation, dot_face_normal=elements)
elif callable(values):
values = sample_function(values, elements, 'face', extrapolation)
if elements.shape.shape.rank > 1: # Different number of X and Y faces
if elements.shape.is_non_uniform: # Different number of X and Y faces
assert isinstance(values, TensorStack), f"values function must return a staggered Tensor but returned {type(values)}"
assert '~vector' in values.shape
if 'vector' in values.shape:
Expand Down Expand Up @@ -203,7 +203,7 @@ def resolution_from_staggered_tensor(values: Tensor, extrapolation: Extrapolatio
x_shape = values.shape.after_gather({'vector': any_dim, '~vector': any_dim})
ext_lower, ext_upper = extrapolation.valid_outer_faces(any_dim)
delta = int(ext_lower) + int(ext_upper) - 1
resolution = x_shape.spatial._replace_single_size(any_dim, x_shape.get_size(any_dim) - delta)
resolution = x_shape.spatial.with_dim_size(any_dim, x_shape.get_size(any_dim) - delta)
return resolution


Expand Down
6 changes: 4 additions & 2 deletions phi/field/_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from phi.math.extrapolation import Extrapolation, ConstantExtrapolation, PERIODIC
from phiml.math import unstack, channel, rename_dims, batch, extrapolation
from ._field import Field, FieldInitializer, as_boundary, slice_off_constant_faces
from phiml.math._tensors import may_vary_along
from phiml.math._tensors import may_vary_along, wrap


def resample(value: Union[Field, Geometry, Tensor, float, FieldInitializer], to: Union[Field, Geometry], keep_boundary=False, **kwargs):
Expand Down Expand Up @@ -48,7 +48,9 @@ def resample(value: Union[Field, Geometry, Tensor, float, FieldInitializer], to:
>>> field.resample(grid, to=grid) == grid
True
"""
assert isinstance(to, (Field, Geometry)), f"'to' must be a Field or Geomoetry but got {to}"
assert isinstance(to, (Field, Geometry)), f"'to' must be a Field or Geometry but got {to}"
if isinstance(to, Geometry):
to = Field(to, wrap(0.), value.extrapolation if isinstance(value, Field) else 0.)
if not isinstance(value, (Field, Geometry, FieldInitializer)):
return to.with_values(value)
if isinstance(value, Field) and keep_boundary:
Expand Down
Loading
Loading