diff --git a/jax_cfd/__init__.py b/jax_cfd/__init__.py deleted file mode 100644 index 6569478..0000000 --- a/jax_cfd/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Defines the JAX-CFD module for computational fluid dynamics.""" - -__version__ = '0.2.1' - -import jax_cfd.base diff --git a/jax_cfd/base/IBM_Force.py b/jax_cfd/base/IBM_Force.py new file mode 100644 index 0000000..45b93bb --- /dev/null +++ b/jax_cfd/base/IBM_Force.py @@ -0,0 +1,115 @@ +import jax.numpy as jnp +import jax +from jax_ib.base import grids + +def integrate_trapz(integrand,dx,dy): + return jnp.trapz(jnp.trapz(integrand,dx=dx),dx=dy) + + +def Integrate_Field_Fluid_Domain(field): + + + grid = field.grid + # offset = field.offset + dxEUL = grid.step[0] + dyEUL = grid.step[1] + # X,Y =grid.mesh(offset) + + return integrate_trapz(field.data,dxEUL,dyEUL) + +def IBM_force_GENERAL(field,Xi,particle_center,geom_param,Grid_p,shape_fn,discrete_fn,surface_fn,dx_dt,domega_dt,rotation,dt): + + grid = field.grid + offset = field.offset + X,Y = grid.mesh(offset) + dxEUL = grid.step[0] + dyEUL = grid.step[1] + current_t = field.bc.time_stamp + #current_t = 0.0 + xp0,yp0 = shape_fn(geom_param,Grid_p) + #print('yp',yp0,'xp',xp0) + #print('angle',current_t,rotation(current_t),particle_center) + #print(yp0) + xp = (xp0)*jnp.cos(rotation(current_t))-(yp0)*jnp.sin(rotation(current_t))+particle_center[0] + yp = (xp0)*jnp.sin(rotation(current_t))+(yp0 )*jnp.cos(rotation(current_t))+particle_center[1] + surface_coord =[(xp)/dxEUL-offset[0],(yp)/dyEUL-offset[1]] + #print(rotation(current_t)) + velocity_at_surface = surface_fn(field,xp,yp) + + if Xi==0: + position_r = -(yp-particle_center[1]) + elif Xi==1: + position_r = (xp-particle_center[0]) + + U0 = dx_dt(current_t) + #print('U0',U0) + Omega=domega_dt(current_t) + UP= U0[Xi] + Omega*position_r + #print(xp) + #print('XI',Xi,UP,len(UP)) + force = (UP -velocity_at_surface)/dt + + # if Xi==0: + #plt.plot(xp,force) + #maxforce = delta_approx_logistjax(xp[0],X,0.003,1) + # maxforce = discrete_fn(xp[3],X) + # plt.imshow(maxforce) + # print('Maxforce',jnp.max(maxforce)) + # print(xp) + x_i = jnp.roll(xp,-1) + y_i = jnp.roll(yp,-1) + dxL = x_i-xp + dyL = y_i-yp + dS = jnp.sqrt(dxL**2 + dyL**2) + + + def calc_force(F,xp,yp,dxi,dyi,dss): + return F*discrete_fn(jnp.sqrt((xp-X)**2 + (yp-Y)**2),0,dxEUL)*dss + #return F*discrete_fn(xp-X,0,dxEUL)*discrete_fn(yp-Y,0,dyEUL)*dss + #return F*discrete_fn(xp,X,dxEUL)*discrete_fn(yp,Y,dyEUL)*dss**2 + def foo(tree_arg): + F,xp,yp,dxi,dyi,dss = tree_arg + return calc_force(F,xp,yp,dxi,dyi,dss) + + def foo_pmap(tree_arg): + #print(tree_arg) + return jnp.sum(jax.vmap(foo,in_axes=1)(tree_arg),axis=0) + divider=jax.device_count() + n = len(xp)//divider + mapped = [] + for i in range(divider): + # print(i) + mapped.append([force[i*n:(i+1)*n],xp[i*n:(i+1)*n],yp[i*n:(i+1)*n],dxL[i*n:(i+1)*n],dyL[i*n:(i+1)*n],dS[i*n:(i+1)*n]]) + #mapped = jnp.array([force,xp,yp]) + #remapped = mapped.reshape(())#jnp.array([[force[:n],xp[:n],yp[:n]],[force[n:],xp[n:],yp[n:]]]) + + #return cfd.grids.GridArray(jnp.sum(jax.pmap(foo_pmap)(jnp.array(mapped)),axis=0),offset,grid) + return jnp.sum(jax.pmap(foo_pmap)(jnp.array(mapped)),axis=0) + +def IBM_Multiple_NEW(field, Xi, particles,discrete_fn,surface_fn,dt): + Grid_p = particles.generate_grid() + shape_fn = particles.shape + Displacement_EQ = particles.Displacement_EQ + Rotation_EQ = particles.Rotation_EQ + Nparticles = len(particles.particle_center) + particle_center = particles.particle_center + geom_param = particles.geometry_param + displacement_param = particles.displacement_param + rotation_param = particles.rotation_param + force = jnp.zeros_like(field.data) + for i in range(Nparticles): + Xc = lambda t:Displacement_EQ([displacement_param[i]],t) + rotation = lambda t:Rotation_EQ([rotation_param[i]],t) + dx_dt = jax.jacrev(Xc) + domega_dt = jax.jacrev(rotation) + force+= IBM_force_GENERAL(field,Xi,particle_center[i],geom_param[i],Grid_p,shape_fn,discrete_fn,surface_fn,dx_dt,domega_dt,rotation,dt) + return grids.GridArray(force,field.offset,field.grid) + + +def calc_IBM_force_NEW_MULTIPLE(all_variables,discrete_fn,surface_fn,dt): + velocity = all_variables.velocity + particles = all_variables.particles + axis = [0,1] + ibm_forcing = lambda field,Xi:IBM_Multiple_NEW(field, Xi, particles,discrete_fn,surface_fn,dt) + + return tuple(grids.GridVariable(ibm_forcing(field,Xi),field.bc) for field,Xi in zip(velocity,axis)) diff --git a/jax_cfd/base/__init__.py b/jax_cfd/base/__init__.py index ea4d7a6..c38e258 100644 --- a/jax_cfd/base/__init__.py +++ b/jax_cfd/base/__init__.py @@ -1,33 +1,16 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +import jax_ib.base.IBM_Force +import jax_ib.base.boundaries +import jax_ib.base.convolution_functions +import jax_ib.base.equations +import jax_ib.base.grids +import jax_ib.base.particle_class +import jax_ib.base.particle_motion +import jax_ib.base.pressure +import jax_ib.base.time_stepping -"""Non-learned "base" physics routines for JAX-CFD.""" - -import jax_cfd.base.advection -import jax_cfd.base.array_utils -import jax_cfd.base.boundaries -import jax_cfd.base.diffusion -import jax_cfd.base.equations -import jax_cfd.base.fast_diagonalization -import jax_cfd.base.finite_differences -import jax_cfd.base.forcings -import jax_cfd.base.funcutils -import jax_cfd.base.grids -import jax_cfd.base.initial_conditions -import jax_cfd.base.interpolation -import jax_cfd.base.pressure -import jax_cfd.base.resize -import jax_cfd.base.subgrid_models -import jax_cfd.base.time_stepping -import jax_cfd.base.validation_problems +import jax_ib.base.advection +import jax_ib.base.interpolation +import jax_ib.base.diffusion +import jax_ib.base.finite_differences +import jax_ib.base.kinematics +import jax_ib.base.array_utils diff --git a/jax_cfd/base/advection.py b/jax_cfd/base/advection.py index f4afd5b..c304c45 100644 --- a/jax_cfd/base/advection.py +++ b/jax_cfd/base/advection.py @@ -15,13 +15,12 @@ """Module for functionality related to advection.""" from typing import Optional, Tuple - import jax import jax.numpy as jnp -from jax_cfd.base import boundaries -from jax_cfd.base import finite_differences as fd -from jax_cfd.base import grids -from jax_cfd.base import interpolation +from jax_ib.base import boundaries +from jax_ib.base import finite_differences as fd +from jax_ib.base import grids +from jax_ib.base import interpolation GridArray = grids.GridArray GridArrayVector = grids.GridArrayVector @@ -71,10 +70,8 @@ def _advect_aligned(cs: GridVariableVector, v: GridVariableVector) -> GridArray: raise ValueError('`cs` and `v` must have the same length;' f'got {len(cs)} vs. {len(v)}.') flux = tuple(c.array * u.array for c, u in zip(cs, v)) - bcs = tuple( - boundaries.get_advection_flux_bc_from_velocity_and_scalar(v[i], cs[i], i) - for i in range(len(v))) - flux = tuple(bc.impose_bc(f) for f, bc in zip(flux, bcs)) + # Flux inherits boundary conditions from cs + flux = tuple(grids.GridVariable(f, c.bc) for f, c in zip(flux, cs)) return -fd.divergence(flux) @@ -105,9 +102,6 @@ def advect_general( Returns: The time derivative of `c` due to advection by `v`. """ - if not boundaries.has_all_periodic_boundary_conditions(c): - raise NotImplementedError( - 'Non-periodic boundary conditions are not implemented.') target_offsets = grids.control_volume_offsets(c) aligned_v = tuple(u_interpolation_fn(u, target_offset, v, dt) for u, target_offset in zip(v, target_offsets)) @@ -142,6 +136,7 @@ def _align_velocities(v: GridVariableVector) -> Tuple[GridVariableVector]: the appropriate face of the control volume centered around `v[j]`. """ grid = grids.consistent_grid(*v) + #grid = v[0].grid offsets = tuple(grids.control_volume_offsets(u) for u in v) aligned_v = tuple( tuple(interpolation.linear(v[i], offsets[i][j]) @@ -172,10 +167,11 @@ def _velocities_to_flux( for i in range(ndim): for j in range(ndim): if i <= j: - bc = boundaries.get_advection_flux_bc_from_velocity_and_scalar( - aligned_v[j][i], aligned_v[i][j], j) - flux[i] += (bc.impose_bc(aligned_v[i][j].array * - aligned_v[j][i].array),) + bc = grids.consistent_boundary_conditions( + aligned_v[i][j], aligned_v[j][i]) + #bc = aligned_v[i][j].bc + flux[i] += (GridVariable(aligned_v[i][j].array * aligned_v[j][i].array, + bc),) else: flux[i] += (flux[j][i],) return tuple(flux) @@ -210,8 +206,7 @@ def convect_linear(v, grid): def advect_van_leer( c: GridVariable, v: GridVariableVector, - dt: float, - mode: str = boundaries.Padding.MIRROR, + dt: float ) -> GridArray: """Computes advection of a scalar quantity `c` by the velocity field `v`. @@ -221,28 +216,18 @@ def advect_van_leer( limitor transformes the scheme into a first order method. For [1] for reference. This function follows the following procedure: - 1. Shifts c to offset < 1 if necessary. - 2. Scalar c now has a well defined right-hand (upwind) value. - 3. Computes upwind flux for each direction. - 4. Computes van leer flux limiter: - a. Use the shifted c to interpolate each component of `v` to the - right-hand (upwind) face of the control volume centered on `c`. - b. Compute the ratio of successive gradients: - In nonperiodic case, the value outside the boundary is not defined. - Mode is used to interpolate past the boundary. - c. Compute flux limiter function. - d. Computes higher order flux correction. - 5. Combines fluxes and assigns flux boundary condition. - 6. Computes the negative divergence of fluxes. - 7. Shifts the computed values back to original offset of c. + 1. Interpolate each component of `v` to the corresponding face of the + control volume centered on `c`. In most cases satisfied by design. + 2. Computes upwind flux for each direction. + 3. Computes higher order flux correction based on neighboring values of `c`. + 4. Combines fluxes and assigns flux boundary condition. + 5. Returns the negative divergence of fluxes. Args: c: the quantity to be transported. v: a velocity field. Should be defined on the same Grid as c. dt: time step for which this scheme is TVD and second order accurate in time. - mode: For non-periodic BC, specifies extrapolation of values beyond the - boundary, which is used by nonlinear interpolation. Returns: The time derivative of `c` due to advection by `v`. @@ -251,94 +236,47 @@ def advect_van_leer( [1]: MIT 18.336 spring 2009 Finite Volume Methods Lecture 19. go/mit-18.336-finite_volume_methods-19 - [2]: - www.ita.uni-heidelberg.de/~dullemond/lectures/num_fluid_2012/Chapter_4.pdf """ # TODO(dkochkov) reimplement this using apply_limiter method. - c_left_var = c - # if the offset is 1., shift by 1 to offset 0. - # otherwise c_right is not defined. - for ax in range(c.grid.ndim): - # int(c.offset[ax] % 1 - c.offset[ax]) = -1 if c.offset[ax] is 1 else - # int(c.offset[ax] % 1 - c.offset[ax]) = 0. - # i.e. this shifts the 1 aligned data to 0 offset, the rest is unchanged. - c_left_var = c.bc.impose_bc( - c_left_var.shift(int(c.offset[ax] % 1 - c.offset[ax]), axis=ax)) - offsets = grids.control_volume_offsets(c_left_var) - # if c offset is 0, aligned_v is at 0.5. - # if c offset is at .5, aligned_v is at 1. + offsets = grids.control_volume_offsets(c) aligned_v = tuple(interpolation.linear(u, offset) for u, offset in zip(v, offsets)) flux = [] - # Assign flux boundary condition - flux_bc = [ - boundaries.get_advection_flux_bc_from_velocity_and_scalar(u, c, direction) - for direction, u in enumerate(v) - ] - # first, compute upwind flux. - for axis, u in enumerate(aligned_v): - c_center = c_left_var.data - # by shifting c_left + 1, c_right is well-defined. - c_right = c_left_var.shift(+1, axis=axis).data + for axis, (u, h) in enumerate(zip(aligned_v, c.grid.step)): + c_center = c.data + c_left = c.shift(-1, axis=axis).data + c_right = c.shift(+1, axis=axis).data upwind_flux = grids.applied(jnp.where)( u.array > 0, u.array * c_center, u.array * c_right) - flux.append(upwind_flux) - # next, compute van_leer correction. - for axis, (u, h) in enumerate(zip(aligned_v, c.grid.step)): - u = u.bc.shift(u.array, int(u.offset[axis] % 1 - u.offset[axis]), axis=axis) - # c is put to offset .5 or 1. - c_center_arr = c.shift(int(1 - c.offset[ax]), axis=ax) - # if c offset is 1, u offset is .5. - # if c offset is .5, u offset is 0. - # u_i is always on the left of c_center_var_i - c_center = c_center_arr.data - # shift -1 are well defined now - # shift +1 is not well defined for c offset 1 because then c(wall + 1) is - # not defined. - # However, the flux that uses c(wall + 1) offset gets overridden anyways - # when flux boundary condition is overridden. - # Thus, any mode can be used here. - c_right = c.bc.shift(c_center_arr, +1, axis=axis, mode=mode).data - c_left = c.bc.shift(c_center_arr, -1, axis=axis).data - # shift -2 is tricky: - # It is well defined if c is periodic. - # Else, c(-1) or c(-1.5) are not defined. - # Then, mode is used to interpolate the values. - c_left_left = c.bc.shift( - c_center_arr, -2, axis, mode=mode).data - - numerator_positive = c_left - c_left_left - numerator_negative = c_right - c_center - numerator = grids.applied(jnp.where)(u > 0, numerator_positive, - numerator_negative) - denominator = grids.GridArray(c_center - c_left, u.offset, u.grid) - # We want to calculate denominator / (abs(denominator) + abs(numerator)) - # To make it differentiable, it needs to be done in stages. - - # ensures that there is no division by 0 - phi_van_leer_denominator_avoid_nans = grids.applied(jnp.where)( - abs(denominator) > 0, (abs(denominator) + abs(numerator)), 1.) - - phi_van_leer_denominator_inv = denominator / phi_van_leer_denominator_avoid_nans - - phi_van_leer = numerator * (grids.applied(jnp.sign)(denominator) + - grids.applied(jnp.sign) - (numerator)) * phi_van_leer_denominator_inv - abs_velocity = abs(u) + + # Van-Leer Flux correction is computed in steps to avoid `nan`s. + # Formula for the flux correction df for advection with positive velocity is + # df_{i} = 0.5 * (1-gamma) * dc_{i} + # dc_{i} = 2(c_{i+1} - c_{i})(c_{i} - c_{i-1})/(c_{i+1}-c_{i}) + # gamma is the courant number = u * dt / h + diffs_prod = 2 * (c_right - c_center) * (c_center - c_left) + neighbor_diff = c_right - c_left + safe = diffs_prod > 0 + # https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where + forward_correction = jnp.where( + safe, diffs_prod / jnp.where(safe, neighbor_diff, 1), 0 + ) + # for negative velocity we simply need to shift the correction along v axis. + # Cast to GridVariable so that we can apply a shift() operation. + forward_correction_array = grids.GridVariable( + grids.GridArray(forward_correction, u.offset, u.grid), u.bc) + backward_correction_array = forward_correction_array.shift(+1, axis) + backward_correction = backward_correction_array.data + abs_velocity = abs(u.array) courant_numbers = (dt / h) * abs_velocity pre_factor = 0.5 * (1 - courant_numbers) * abs_velocity - flux_correction = pre_factor * phi_van_leer - # Shift back onto original offset. - flux_correction = flux_bc[axis].shift( - flux_correction, int(offsets[axis][axis] - u.offset[axis]), axis=axis) - flux[axis] += flux_correction - flux = tuple(flux_bc[axis].impose_bc(f) for axis, f in enumerate(flux)) + flux_correction = pre_factor * grids.applied(jnp.where)( + u.array > 0, forward_correction, backward_correction) + flux.append(upwind_flux + flux_correction) + # Assign flux boundary condition + flux = tuple(GridVariable(f, c.bc) for f in flux) advection = -fd.divergence(flux) - # shift the variable back onto the original offset - for ax in range(c.grid.ndim): - advection = c.bc.shift( - advection, -int(c.offset[ax] % 1 - c.offset[ax]), axis=ax) return advection diff --git a/jax_cfd/base/array_utils.py b/jax_cfd/base/array_utils.py index d40d153..9cff5b8 100644 --- a/jax_cfd/base/array_utils.py +++ b/jax_cfd/base/array_utils.py @@ -18,8 +18,8 @@ import jax import jax.numpy as jnp -from jax_cfd.base import boundaries -from jax_cfd.base import grids +from jax_ib.base import boundaries +from jax_ib.base import grids import numpy as np import scipy.linalg @@ -57,7 +57,7 @@ def slice_along_axis( Returns: Slice of `inputs` defined by `idx` along axis `axis`. """ - arrays, tree_def = jax.tree.flatten(inputs) + arrays, tree_def = jax.tree_flatten(inputs) ndims = set(a.ndim for a in arrays) if expect_same_dims and len(ndims) != 1: raise ValueError('arrays in `inputs` expected to have same ndims, but have ' @@ -68,7 +68,7 @@ def slice_along_axis( slc = tuple(idx if j == _normalize_axis(axis, ndim) else slice(None) for j in range(ndim)) sliced.append(array[slc]) - return jax.tree.unflatten(tree_def, sliced) + return jax.tree_unflatten(tree_def, sliced) def split_along_axis( @@ -115,22 +115,22 @@ def split_axis( Raises: ValueError: if arrays in `inputs` don't have unique size along `axis`. """ - arrays, tree_def = jax.tree.flatten(inputs) + arrays, tree_def = jax.tree_flatten(inputs) axis_shapes = set(a.shape[axis] for a in arrays) if len(axis_shapes) != 1: raise ValueError(f'Arrays must have equal sized axis but got {axis_shapes}') axis_shape, = axis_shapes splits = [jnp.split(a, axis_shape, axis=axis) for a in arrays] if not keep_dims: - splits = jax.tree.map(lambda a: jnp.squeeze(a, axis), splits) + splits = jax.tree_map(lambda a: jnp.squeeze(a, axis), splits) splits = zip(*splits) - return tuple(jax.tree.unflatten(tree_def, leaves) for leaves in splits) + return tuple(jax.tree_unflatten(tree_def, leaves) for leaves in splits) def concat_along_axis(pytrees, axis): """Concatenates `pytrees` along `axis`.""" concat_leaves_fn = lambda *args: jnp.concatenate(args, axis) - return jax.tree.map(concat_leaves_fn, *pytrees) + return jax.tree_map(concat_leaves_fn, *pytrees) def block_reduce( @@ -172,6 +172,14 @@ def laplacian_matrix(size: int, step: float) -> np.ndarray: column[1] = column[-1] = 1 / step**2 return scipy.linalg.circulant(column) +def laplacian_matrix_neumann(size: int, step: float) -> np.ndarray: + """Create 1D Laplacian operator matrix, with homogeneous Neumann BC.""" + column = np.zeros(size) + column[0] = -2 / step ** 2 + column[1] = 1 / step ** 2 + matrix = scipy.linalg.toeplitz(column) + matrix[0, 0] = matrix[-1, -1] = -1 / step**2 + return matrix def _laplacian_boundary_dirichlet_cell_centered(laplacians: List[Array], grid: grids.Grid, axis: int, diff --git a/jax_cfd/base/boundaries.py b/jax_cfd/base/boundaries.py index 275aad2..3756c1f 100644 --- a/jax_cfd/base/boundaries.py +++ b/jax_cfd/base/boundaries.py @@ -14,20 +14,22 @@ """Classes that specify how boundary conditions are applied to arrays.""" import dataclasses -import math -from typing import Optional, Sequence, Tuple, Union - -import jax +from typing import Any, Callable, Iterable, Sequence, Tuple, Optional, Union from jax import lax +import jax import jax.numpy as jnp -from jax_cfd.base import grids +from jax_ib.base import grids import numpy as np +import scipy +from jax.tree_util import register_pytree_node_class +from jax_ib.base import particle_class BoundaryConditions = grids.BoundaryConditions GridArray = grids.GridArray GridVariable = grids.GridVariable GridVariableVector = grids.GridVariableVector Array = Union[np.ndarray, jax.Array] +BCArray = grids.BCArray class BCType: @@ -35,13 +37,8 @@ class BCType: DIRICHLET = 'dirichlet' NEUMANN = 'neumann' - -class Padding: - MIRROR = 'mirror' - EXTEND = 'extend' - - -@dataclasses.dataclass(init=False, frozen=True) +@register_pytree_node_class +@dataclasses.dataclass(init=False, frozen=False) class ConstantBoundaryConditions(BoundaryConditions): """Boundary conditions for a PDE variable that are constant in space and time. @@ -59,20 +56,47 @@ class ConstantBoundaryConditions(BoundaryConditions): """ types: Tuple[Tuple[str, str], ...] bc_values: Tuple[Tuple[Optional[float], Optional[float]], ...] - - def __init__(self, types: Sequence[Tuple[str, str]], - values: Sequence[Tuple[Optional[float], Optional[float]]]): + boundary_fn: Callable[...,Optional[float]] + time_stamp: Optional[float] + def __init__(self, + time_stamp: Optional[float],values: Sequence[Tuple[Optional[float], Optional[float]]],types: Sequence[Tuple[str, str]],boundary_fn:Callable[...,Optional[float]]): types = tuple(types) values = tuple(values) - object.__setattr__(self, 'types', types) + boundary_fn = boundary_fn + time_stamp = time_stamp + object.__setattr__(self, 'bc_values', values) + object.__setattr__(self, 'boundary_fn', boundary_fn) + object.__setattr__(self, 'time_stamp', time_stamp if time_stamp is not None else []) + object.__setattr__(self, 'types', types) + #if boundary_fn or not boundary_fn: + + #else: + # object.__setattr__(self, 'boundary_fn', None) + # object.__setattr__(self, 'time_stamp', None) + + + def tree_flatten(self): + """Returns flattening recipe for GridVariable JAX pytree.""" + children = (self.time_stamp,self.bc_values,) + aux_data = (self.types,self.boundary_fn) + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Returns unflattening recipe for GridVariable JAX pytree.""" + return cls(*children, *aux_data) + + + def update_bc_(self,time_stamp: float, dt: float): + return time_stamp + dt + def shift( self, u: GridArray, offset: int, axis: int, - mode: Optional[str] = Padding.EXTEND, ) -> GridArray: """Shift an GridArray by `offset`. @@ -80,53 +104,20 @@ def shift( u: an `GridArray` object. offset: positive or negative integer offset to shift. axis: axis to shift along. - mode: type of padding to use in non-periodic case. - Mirror mirrors the flow across the boundary. - Extend extends the last well-defined value past the boundary. Returns: A copy of `u`, shifted by `offset`. The returned `GridArray` has offset `u.offset + offset`. """ - padded = self._pad(u, offset, axis, mode=mode) + padded = self._pad(u, offset, axis) trimmed = self._trim(padded, -offset, axis) return trimmed - def _is_aligned(self, u: GridArray, axis: int) -> bool: - """Checks if array u contains all interior domain information. - - For dirichlet edge aligned boundary, the value that lies exactly on the - boundary does not have to be specified by u. - Neumann edge aligned boundary is not defined. - - Args: - u: array that should contain interior data - axis: axis along which to check - - Returns: - True if u is aligned, and raises error otherwise. - """ - size_diff = u.shape[axis] - u.grid.shape[axis] - if self.types[axis][0] == BCType.DIRICHLET and np.isclose( - u.offset[axis], 1): - size_diff += 1 - if self.types[axis][1] == BCType.DIRICHLET and np.isclose( - u.offset[axis], 1): - size_diff += 1 - if self.types[axis][0] == BCType.NEUMANN and np.isclose( - u.offset[axis] % 1, 0): - raise NotImplementedError('Edge-aligned neumann BC are not implemented.') - if size_diff < 0: - raise ValueError( - 'the GridArray does not contain all interior grid values.') - return True - def _pad( self, u: GridArray, width: int, axis: int, - mode: Optional[str] = Padding.EXTEND, ) -> GridArray: """Pad a GridArray. @@ -136,22 +127,18 @@ def _pad( ghost cell is required. More ghost cells are used only in LES filtering/CNN application. + No padding past 1 ghost cell is implemented for Neumann BC. + Args: u: a `GridArray` object. width: number of elements to pad along axis. Use negative value for lower boundary or positive value for upper boundary. axis: axis to pad along. - mode: type of padding to use in non-periodic case. - Mirror mirrors the array values across the boundary. - Extend extends the last well-defined array value past the boundary. - Mode is only needed if the padding extends past array values that are - defined by the physics. In these cases, no mode is necessary. This - also means periodic boundaries do not require a mode and can use - mode=None. Returns: Padded array, elongated along the indicated axis. """ + def make_padding(width): if width < 0: # pad lower boundary bc_type = self.types[axis][0] @@ -167,13 +154,12 @@ def make_padding(width): full_padding, padding, bc_type = make_padding(width) offset = list(u.offset) offset[axis] -= padding[0] - if bc_type == BCType.PERIODIC: - need_trimming = 'both' # need to trim both sides - elif width >= 0: - need_trimming = 'right' # only one side needs to be trimmed - else: - need_trimming = 'left' # only one side needs to be trimmed - u, trimmed_padding = self._trim_padding(u, axis, need_trimming) + if not (bc_type == BCType.PERIODIC or + bc_type == BCType.DIRICHLET) and abs(width) > 1: + raise ValueError( + f'Padding past 1 ghost cell is not defined in {bc_type} case.') + + u, trimmed_padding = self._trim_padding(u) data = u.data full_padding[axis] = tuple( pad + trimmed_pad @@ -182,63 +168,62 @@ def make_padding(width): if bc_type == BCType.PERIODIC: # for periodic, all grid points must be there. Otherwise padding doesn't # make sense. - + # Don't pad a trimmed periodic array. + if u.grid.shape[axis] > u.shape[axis]: + raise ValueError('the GridArray shape does not match the grid.') # self.values are ignored here pad_kwargs = dict(mode='wrap') data = jnp.pad(data, full_padding, **pad_kwargs) elif bc_type == BCType.DIRICHLET: if np.isclose(u.offset[axis] % 1, 0.5): # cell center - # If only one or 0 value is needed, no mode is necessary. - # All modes would return the same values. - if np.isclose(sum(full_padding[axis]), 1) or np.isclose( - sum(full_padding[axis]), 0): - mode = Padding.MIRROR - - if mode == Padding.MIRROR: - # make the linearly interpolated value equal to the boundary by - # setting the padded values to the negative symmetric values - data = (2 * jnp.pad( - data, - full_padding, - mode='constant', - constant_values=self.bc_values) - - jnp.pad(data, full_padding, mode='symmetric')) - elif mode == Padding.EXTEND: - # computes the well-defined ghost cell and sets the rest of padding - # values equal to the ghost cell. - data = (2 * jnp.pad( - data, - full_padding, - mode='constant', - constant_values=self.bc_values) - - jnp.pad(data, full_padding, mode='edge')) - else: - raise NotImplementedError(f'Mode {mode} is not implemented yet.') + # make the linearly interpolated value equal to the boundary by setting + # the padded values to the negative symmetric values + # for dirichlet 0.5 offset, all grid points must be there. + # Otherwise padding doesn't make sense. + if u.grid.shape[axis] > u.shape[axis]: + raise ValueError('the GridArray shape does not match the grid.') + data = (2 * jnp.pad( + data, full_padding, mode='constant', constant_values=self.bc_values) + - jnp.pad(data, full_padding, mode='symmetric')) elif np.isclose(u.offset[axis] % 1, 0): # cell edge # u specifies the values on the interior CV. Thus, first the value on # the boundary needs to be added to the array, if not specified by the # interior CV values. # Then the mirrored ghost cells need to be appended. - # if only one value is needed, no mode is necessary. - if np.isclose(sum(full_padding[axis]), 1) or np.isclose( - sum(full_padding[axis]), 0): - data = jnp.pad( - data, - full_padding, - mode='constant', - constant_values=self.bc_values) - elif sum(full_padding[axis]) > 1: - if mode == Padding.MIRROR: - # make boundary-only padding - bc_padding = [(0, 0)] * u.grid.ndim - bc_padding[axis] = tuple( - 1 if pad > 0 else 0 for pad in full_padding[axis]) + # for dirichlet cell-face aligned offset, 1 grid_point can be missing. + # Otherwise padding doesn't make sense. + if u.grid.shape[axis] > u.shape[axis] + 1: + raise ValueError('For a dirichlet cell-face boundary condition, ' + + 'the GridArray has more than 1 grid point missing.') + elif u.grid.shape[axis] == u.shape[axis] + 1 and not np.isclose( + u.offset[axis], 1): + raise ValueError('For a dirichlet cell-face boundary condition, ' + + 'the GridArray has more than 1 grid point missing.') + + def _needs_pad_with_boundary_value(): + if (np.isclose(u.offset[axis], 0) and + width > 0) or (np.isclose(u.offset[axis], 1) and width < 0): + return True + elif u.grid.shape[axis] == u.shape[axis] + 1: + return True + else: + return False + + if _needs_pad_with_boundary_value(): + if np.isclose(abs(width), 1): + data = jnp.pad( + data, + full_padding, + mode='constant', + constant_values=self.bc_values) + elif abs(width) > 1: + bc_padding, _, _ = make_padding(int(width / + abs(width))) # makes it 1 pad # subtract the padded cell - full_padding_past_bc = [(0, 0)] * u.grid.ndim - full_padding_past_bc[axis] = tuple( - pad - 1 if pad > 0 else 0 for pad in full_padding[axis]) + full_padding_past_bc, _, _ = make_padding( + (abs(width) - 1) * int(width / abs(width))) # makes it 1 pad # here we are adding 0 boundary cell with 0 value expanded_data = jnp.pad( data, bc_padding, mode='constant', constant_values=(0, 0)) @@ -250,42 +235,33 @@ def make_padding(width): mode='constant', constant_values=tuple(padding_values)) - jnp.pad( expanded_data, full_padding_past_bc, mode='reflect') - elif mode == Padding.EXTEND: - data = jnp.pad( - data, - full_padding, - mode='constant', - constant_values=self.bc_values) - else: - raise NotImplementedError(f'Mode {mode} is not implemented yet.') + else: # dirichlet cell-face aligned + padding_values = list(self.bc_values) + padding_values[axis] = [pad / 2 for pad in padding_values[axis]] + data = 2 * jnp.pad( + data, + full_padding, + mode='constant', + constant_values=tuple(padding_values)) - jnp.pad( + data, full_padding, mode='reflect') else: raise ValueError('expected offset to be an edge or cell center, got ' f'offset[axis]={u.offset[axis]}') elif bc_type == BCType.NEUMANN: - if not np.isclose(u.offset[axis] % 1, 0.5): - raise ValueError( - 'expected offset to be cell center for neumann bc, got ' - f'offset[axis]={u.offset[axis]}') + # for neumann, all grid points must be there. + # Otherwise padding doesn't make sense. + if u.grid.shape[axis] > u.shape[axis]: + raise ValueError('the GridArray shape does not match the grid.') + if not (np.isclose(u.offset[axis] % 1, 0) or + np.isclose(u.offset[axis] % 1, 0.5)): + raise ValueError('expected offset to be an edge or cell center, got ' + f'offset[axis]={u.offset[axis]}') else: # When the data is cell-centered, computes the backward difference. - - # if only one value is needed, no mode is necessary. Default mode is - # provided, although all modes would return the same values. - if np.isclose(sum(full_padding[axis]), 1) or np.isclose( - sum(full_padding[axis]), 0): - np_mode = 'symmetric' - elif mode == Padding.MIRROR: - np_mode = 'symmetric' - elif mode == Padding.EXTEND: - np_mode = 'edge' - else: - raise NotImplementedError(f'Mode {mode} is not implemented yet.') - # ensures that finite_differences.backward_difference satisfies the - # boundary condition. - derivative_direction = float(width // max(1, abs(width))) + # When the data is on cell edges, boundary is set such that + # (u_last_interior - u_boundary)/grid_step = neumann_bc_value. data = ( - jnp.pad(data, full_padding, mode=np_mode) - - derivative_direction * u.grid.step[axis] * + jnp.pad(data, full_padding, mode='edge') + u.grid.step[axis] * (jnp.pad(data, full_padding, mode='constant') - jnp.pad( data, full_padding, @@ -324,110 +300,33 @@ def _trim( offset[axis] += padding[0] return GridArray(data, tuple(offset), u.grid) - def _trim_padding(self, - u: grids.GridArray, - axis: int = 0, - trim_side: str = 'both'): - """Trims padding from a GridArray along axis and returns the array interior. + def _trim_padding(self, u: grids.GridArray, axis=0): + """Trim all padding from a GridArray. Args: u: a `GridArray` object. axis: axis to trim along. - trim_side: if 'both', trims both sides. If 'right', trims the right side. - If 'left', the left side. Returns: - Trimmed array, shrunk along the indicated axis side. + Trimmed array, shrunk along the indicated axis to match + u.grid.shape[axis]. """ padding = (0, 0) - if u.shape[axis] >= u.grid.shape[axis]: + if u.shape[axis] > u.grid.shape[axis]: # number of cells that were padded on the left negative_trim = 0 - if u.offset[axis] <= 0 and (trim_side == 'both' or trim_side == 'left'): - negative_trim = -math.ceil(-u.offset[axis]) - # periodic is a special case. Shifted data might still contain all the - # information. - if self.types[axis][0] == BCType.PERIODIC: - negative_trim = max(negative_trim, u.grid.shape[axis] - u.shape[axis]) - # for both DIRICHLET and NEUMANN cases the value on grid.domain[0] is - # a dependent value. - elif np.isclose(u.offset[axis] % 1, 0): - negative_trim -= 1 + if u.offset[axis] < 0: + negative_trim = -round(-u.offset[axis]) u = self._trim(u, negative_trim, axis) # number of cells that were padded on the right - positive_trim = 0 - if (trim_side == 'right' or trim_side == 'both'): - # periodic is a special case. Boundary on one side depends on the other - # side. - if self.types[axis][1] == BCType.PERIODIC: - positive_trim = max(u.shape[axis] - u.grid.shape[axis], 0) - else: - # for other cases, where to trim depends only on the boundary type - # and data offset. - last_u_offset = u.shape[axis] + u.offset[axis] - 1 - boundary_offset = u.grid.shape[axis] - if last_u_offset >= boundary_offset: - positive_trim = math.ceil(last_u_offset - boundary_offset) - if self.types[axis][1] == BCType.DIRICHLET and np.isclose( - u.offset[axis] % 1, 0): - positive_trim += 1 + positive_trim = u.shape[axis] - u.grid.shape[axis] if positive_trim > 0: u = self._trim(u, positive_trim, axis) # combining existing padding with new padding - padding = (-negative_trim, positive_trim) + padding = (negative_trim, positive_trim) return u, padding - def pad(self, - u: GridArray, - width: Union[Tuple[int, int], int], - axis: int, - mode: Optional[str] = Padding.EXTEND, - ) -> GridArray: - """Wrapper for _pad. - - Args: - u: a `GridArray` object. - width: number of elements to pad along axis. If width is an int, use - negative value for lower boundary or positive value for upper boundary. - If a tuple, pads with width[0] on the left and width[1] on the right. - axis: axis to pad along. - mode: type of padding to use in non-periodic case. - Mirror mirrors the array values across the boundary. - Extend extends the last well-defined array value past the boundary. - - Returns: - Padded array, elongated along the indicated axis. - """ - _ = self._is_aligned(u, axis) - if isinstance(width, int): - u = self._pad(u, width, axis, mode=mode) - else: - u = self._pad(u, -width[0], axis, mode=mode) - u = self._pad(u, width[1], axis, mode=mode) - return u - - def pad_all(self, - u: GridArray, - width: Tuple[Tuple[int, int], ...], - mode: Optional[str] = Padding.EXTEND - ) -> GridArray: - """Pads along all axes with pad width specified by width tuple. - - Args: - u: a `GridArray` object. - width: Tuple of padding width for each side for each axis. - mode: type of padding to use in non-periodic case. - Mirror mirrors the array values across the boundary. - Extend extends the last well-defined array value past the boundary. - - Returns: - Padded array, elongated along all axes. - """ - for axis in range(u.grid.ndim): - u = self.pad(u, width[axis], axis, mode=mode) - return u - - def values( # pytype: disable=signature-mismatch # overriding-parameter-count-checks + def values( self, axis: int, grid: grids.Grid) -> Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]: """Returns boundary values on the grid along axis. @@ -461,17 +360,24 @@ def trim_boundary(self, u: grids.GridArray) -> grids.GridArray: A GridArray shrunk along certain dimensions. """ for axis in range(u.grid.ndim): - _ = self._is_aligned(u, axis) u, _ = self._trim_padding(u, axis) + if u.shape != u.grid.shape: + raise ValueError('the GridArray has already been trimmed.') + for axis in range(u.grid.ndim): + if np.isclose(u.offset[axis], + 0.0) and self.types[axis][0] == BCType.DIRICHLET: + u = self._trim(u, -1, axis) + elif np.isclose(u.offset[axis], + 1.0) and self.types[axis][1] == BCType.DIRICHLET: + u = self._trim(u, 1, axis) return u def pad_and_impose_bc( self, u: grids.GridArray, - offset_to_pad_to: Optional[Tuple[float,...]] = None, - mode: Optional[str] = Padding.EXTEND, - ) -> grids.GridVariable: - """Returns GridVariable with correct boundary values. + offset_to_pad_to: Optional[Tuple[float, + ...]] = None) -> grids.GridVariable: + """Returns GridVariable with correct boundary condition. Some grid points of GridArray might coincide with boundary. This ensures that the GridVariable.array agrees with GridVariable.bc. @@ -481,23 +387,19 @@ def pad_and_impose_bc( offset_to_pad_to: a Tuple of desired offset to pad to. Note that if the function is given just an interior array in dirichlet case, it can pad to both 0 offset and 1 offset. - mode: type of padding to use in non-periodic case. - Mirror mirrors the flow across the boundary. - Extend extends the last well-defined value past the boundary. Returns: - A GridVariable that has correct boundary values. + A GridVariable that has correct boundary. """ if offset_to_pad_to is None: offset_to_pad_to = u.offset for axis in range(u.grid.ndim): - _ = self._is_aligned(u, axis) if self.types[axis][0] == BCType.DIRICHLET and np.isclose( u.offset[axis], 1.0): if np.isclose(offset_to_pad_to[axis], 1.0): - u = self._pad(u, 1, axis, mode=mode) + u = self._pad(u, 1, axis) elif np.isclose(offset_to_pad_to[axis], 0.0): - u = self._pad(u, -1, axis, mode=mode) + u = self._pad(u, -1, axis) return grids.GridVariable(u, self) def impose_bc(self, u: grids.GridArray) -> grids.GridVariable: @@ -509,16 +411,17 @@ def impose_bc(self, u: grids.GridArray) -> grids.GridVariable: u: a `GridArray` object. Returns: - A GridVariable that has correct boundary values and is restricted to the - domain. + A GridVariable that has correct boundary. """ offset = u.offset - u = self.trim_boundary(u) + if u.shape == u.grid.shape: + u = self.trim_boundary(u) return self.pad_and_impose_bc(u, offset) trim = _trim + pad = _pad - +@register_pytree_node_class class HomogeneousBoundaryConditions(ConstantBoundaryConditions): """Boundary conditions for a PDE variable. @@ -538,9 +441,106 @@ def __init__(self, types: Sequence[Tuple[str, str]]): ndim = len(types) values = ((0.0, 0.0),) * ndim - super(HomogeneousBoundaryConditions, self).__init__(types, values) + bc_fn = lambda x: x + time_stamp = 0.0 + super(HomogeneousBoundaryConditions, self).__init__(time_stamp, values,types,bc_fn) + + +@register_pytree_node_class +class TimeDependentBoundaryConditions(ConstantBoundaryConditions): + """Boundary conditions for a PDE variable. + + Example usage: + grid = Grid((10, 10)) + array = GridArray(np.zeros((10, 10)), offset=(0.5, 0.5), grid) + bc = ConstantBoundaryConditions(((BCType.PERIODIC, BCType.PERIODIC), + (BCType.DIRICHLET, BCType.DIRICHLET))) + u = GridVariable(array, bc) + + Attributes: + types: `types[i]` is a tuple specifying the lower and upper BC types for + dimension `i`. + """ + + def __init__(self, types: Sequence[Tuple[str, str]],values: Sequence[Tuple[Optional[float], Optional[float]]],boundary_fn: Callable[..., Optional[float]],time_stamp: Optional[float]): + + #ndim = len(types) + #values = ((0.0, 0.0),) * ndim + + super(TimeDependentBoundaryConditions, self).__init__(types, values,boundary_fn,time_stamp) + + def tree_flatten(self): + """Returns flattening recipe for GridVariable JAX pytree.""" + children = (self.bc_values,) + aux_data = (self.time_stamp,self.types,self.boundary_fn,) + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Returns unflattening recipe for GridVariable JAX pytree.""" + return cls(*children, *aux_data) + + +def boundary_function(t): + A=1 + B = 1 + freq = 1 + return 1+0*(A*jnp.cos(freq*t)+B*jnp.sin(freq*t)) + +def Reserve_BC(all_variable: particle_class.All_Variables,step_time: float) -> particle_class.All_Variables: + v = all_variable.velocity + particles = all_variable.particles + pressure = all_variable.pressure + Drag = all_variable.Drag + Step_count = all_variable.Step_count + MD_var = all_variable.MD_var + bcfn = v[0].bc.boundary_fn + bcfny = v[1].bc.boundary_fn + + dt = step_time + ts = v[0].bc.time_stamp + dt# v[0].bc.time_stamp #v[0].bc.update_bc_(v[0].bc.time_stamp,dt) + #ts = dt + vx_bc = ((bcfn[0](ts),bcfn[1](0.0)),(bcfn[2](ts),bcfn[3](0.0))) + vy_bc = ((bcfny[0](ts),bcfny[1](0.0)),(bcfny[2](ts),bcfny[3](0.0))) + #vel_bc =(Moving_wall_boundary_conditions(ndim=2,bc_vals=vx_bc,time_stamp=ts,bc_fn=bcfn),Moving_wall_boundary_conditions(ndim=2,bc_vals=vy_bc,time_stamp=ts,bc_fn=bcfn)) + vel_bc = (ConstantBoundaryConditions(values=vx_bc,time_stamp=ts,types=v[0].bc.types,boundary_fn=bcfn), + ConstantBoundaryConditions(values=vy_bc,time_stamp=ts,types=v[1].bc.types,boundary_fn=bcfny)) + #return v + #return tuple(grids.GridVariable(u.array, u.bc) for u in v) + + v_updated = tuple( + grids.GridVariable(u.array, bc) for u, bc in zip(v, vel_bc)) + return particle_class.All_Variables(particles,v_updated,pressure,Drag,Step_count,MD_var) + + + +def update_BC(all_variable: particle_class.All_Variables,step_time: float) -> particle_class.All_Variables: + v = all_variable.velocity + particles = all_variable.particles + pressure = all_variable.pressure + Drag = all_variable.Drag + Step_count = all_variable.Step_count + MD_var = all_variable.MD_var + bcfn = v[0].bc.boundary_fn + bcfny = v[1].bc.boundary_fn + + dt = step_time + ts = v[0].bc.time_stamp + dt# v[0].bc.time_stamp #v[0].bc.update_bc_(v[0].bc.time_stamp,dt) + #ts = dt + vx_bc = ((bcfn[0](ts),bcfn[1](ts)),(bcfn[2](ts),bcfn[3](ts))) + vy_bc = ((bcfny[0](ts),bcfny[1](ts)),(bcfny[2](ts),bcfny[3](ts))) + #vel_bc =(Moving_wall_boundary_conditions(ndim=2,bc_vals=vx_bc,time_stamp=ts,bc_fn=bcfn),Moving_wall_boundary_conditions(ndim=2,bc_vals=vy_bc,time_stamp=ts,bc_fn=bcfn)) + vel_bc = (ConstantBoundaryConditions(values=vx_bc,time_stamp=ts,types=v[0].bc.types,boundary_fn=bcfn), + ConstantBoundaryConditions(values=vy_bc,time_stamp=ts,types=v[1].bc.types,boundary_fn=bcfny)) + #return v + #return tuple(grids.GridVariable(u.array, u.bc) for u in v) + + v_updated = tuple( + grids.GridVariable(u.array, bc) for u, bc in zip(v, vel_bc)) + return particle_class.All_Variables(particles,v_updated,pressure,Drag,Step_count,MD_var) + # Convenience utilities to ease updating of BoundaryConditions implementation def periodic_boundary_conditions(ndim: int) -> ConstantBoundaryConditions: """Returns periodic BCs for a variable with `ndim` spatial dimension.""" @@ -548,6 +548,20 @@ def periodic_boundary_conditions(ndim: int) -> ConstantBoundaryConditions: ((BCType.PERIODIC, BCType.PERIODIC),) * ndim) + +def Radom_velocity_conditions(ndim: int) -> ConstantBoundaryConditions: + """Returns periodic BCs for a variable with `ndim` spatial dimension.""" + + values = ((0.0, 0.0),) * ndim + bc_fn = lambda x: x + time_stamp = 0.0 + return Moving_wall_boundary_conditions( + ndim, + bc_vals=values, + time_stamp=time_stamp, + bc_fn=bc_fn,) + + def dirichlet_boundary_conditions( ndim: int, bc_vals: Optional[Sequence[Tuple[float, float]]] = None, @@ -617,6 +631,76 @@ def channel_flow_boundary_conditions( return ConstantBoundaryConditions(bc_type, bc_vals) + + +def Moving_wall_boundary_conditions( + ndim: int, + bc_vals: Optional[Sequence[Tuple[float, float]]], + time_stamp: Optional[float], + bc_fn: Callable[...,Optional[float]], + +) -> ConstantBoundaryConditions: + """Returns BCs periodic for dimension 0 and Dirichlet for dimension 1. + + Args: + ndim: spatial dimension. + bc_fn: function describing the time dependent boundary condition + bc_vals: A tuple of lower and upper boundary values for each dimension. + If None, returns Homogeneous BC. For periodic dimensions the lower, upper + boundary values should be (None, None). + + Returns: + BoundaryCondition instance. + """ + bc_type = ((BCType.PERIODIC, BCType.PERIODIC), + (BCType.DIRICHLET, BCType.DIRICHLET)) + for _ in range(ndim - 2): + bc_type += ((BCType.PERIODIC, BCType.PERIODIC),) + + + + + return ConstantBoundaryConditions(values=bc_vals,time_stamp=time_stamp,types=bc_type,boundary_fn=bc_fn) + + +def Far_field_boundary_conditions( + ndim: int, + bc_vals: Optional[Sequence[Tuple[float, float]]], + time_stamp: Optional[float], + bc_fn: Callable[...,Optional[float]], + +) -> ConstantBoundaryConditions: + """Returns BCs periodic for dimension 0 and Dirichlet for dimension 1. + + Args: + ndim: spatial dimension. + bc_fn: function describing the time dependent boundary condition + bc_vals: A tuple of lower and upper boundary values for each dimension. + If None, returns Homogeneous BC. For periodic dimensions the lower, upper + boundary values should be (None, None). + + Returns: + BoundaryCondition instance. + """ + bc_type = ((BCType.DIRICHLET, BCType.DIRICHLET), + (BCType.DIRICHLET, BCType.DIRICHLET)) + for _ in range(ndim - 2): + bc_type += ((BCType.DIRICHLET, BCType.DIRICHLET),) + + + + + return ConstantBoundaryConditions(values=bc_vals,time_stamp=time_stamp,types=bc_type,boundary_fn=bc_fn) + +def find_extremum(fn,extrema,i_guess): + if extrema == 'maximum': + direc = -1 + elif extrema == 'minimum': + direc = 1 + else: + raise ValueError('No extrema was correctly identified. For maximum, type "maiximum". For minimization, type "minimum". ') + return fn(scipy.optimize.fmin(lambda x: direc*fn(x), i_guess)) + def periodic_and_neumann_boundary_conditions( bc_vals: Optional[Tuple[float, float]] = None) -> ConstantBoundaryConditions: @@ -694,7 +778,7 @@ def consistent_boundary_conditions(*arrays: GridVariable) -> Tuple[str, ...]: *arrays: a list of gridvariables. Returns: - a list of types of boundaries corresponding to each axis if + a list of types of boundaries corresponding to each axis if they are consistent. """ bc_types = [] @@ -710,101 +794,86 @@ def consistent_boundary_conditions(*arrays: GridVariable) -> Tuple[str, ...]: return tuple(bc_types) -def get_pressure_bc_from_velocity( - v: GridVariableVector) -> HomogeneousBoundaryConditions: +def get_pressure_bc_from_velocity(v: GridVariableVector) -> BoundaryConditions: """Returns pressure boundary conditions for the specified velocity.""" # assumes that if the boundary is not periodic, pressure BC is zero flux. velocity_bc_types = consistent_boundary_conditions(*v) pressure_bc_types = [] + bc_value = ((0.0,0.0),(0.0,0.0)) + Bc_f = v[0].bc.boundary_fn for velocity_bc_type in velocity_bc_types: if velocity_bc_type == 'periodic': pressure_bc_types.append((BCType.PERIODIC, BCType.PERIODIC)) else: pressure_bc_types.append((BCType.NEUMANN, BCType.NEUMANN)) - return HomogeneousBoundaryConditions(pressure_bc_types) + return ConstantBoundaryConditions(values=bc_value,time_stamp=2.0,types=pressure_bc_types,boundary_fn=Bc_f) + def get_advection_flux_bc_from_velocity_and_scalar( u: GridVariable, c: GridVariable, - flux_direction: int) -> ConstantBoundaryConditions: + flux_direction: int) -> BoundaryConditions: """Returns advection flux boundary conditions for the specified velocity. - Infers advection flux boundary condition in flux direction from scalar c and velocity u in direction flux_direction. - The flux boundary condition should be used only to compute divergence. If the boundaries are periodic, flux is periodic. In nonperiodic case, flux boundary parallel to flux direction is homogeneous dirichlet. In nonperiodic case if flux direction is normal to the wall, the - function supports 2 cases: - 1) Nonporous boundary, corresponding to homogeneous flux bc. - 2) Pourous boundary with constant flux, corresponding to - both the velocity and scalar with Homogeneous Neumann bc. - - This function supports only these cases because all other cases result in - time dependent flux boundary condition. - + function checks that the boundary needed is nonporous and returns the + homogeneous bc. Otherwise throws an error. Args: u: velocity component in flux_direction. c: scalar to advect. flux_direction: direction of velocity. - Returns: BoundaryCondition instance for advection flux of c in flux_direction. """ # only no penetration and periodic boundaries are supported. flux_bc_types = [] - flux_bc_values = [] - if not isinstance(u.bc, HomogeneousBoundaryConditions): + if not isinstance(u.bc, ConstantBoundaryConditions): raise NotImplementedError( - f'Flux boundary condition is not implemented for velocity with {u.bc}') + f'Flux boundary condition is not implemented for {u.bc, c.bc}') for axis in range(c.grid.ndim): if u.bc.types[axis][0] == 'periodic': flux_bc_types.append((BCType.PERIODIC, BCType.PERIODIC)) - flux_bc_values.append((None, None)) elif flux_direction != axis: - # This is not technically correct. Flux boundary condition in most cases - # is a time dependent function of the current values of the scalar - # and velocity. However, because flux is used only to take divergence, the - # boundary condition on the flux along the boundary parallel to the flux - # direction has no influence on the computed divergence because the - # boundary condition only alters ghost cells, while divergence is computed - # on the interior. - # To simplify the code and allow for flux to be wrapped in gridVariable, - # we are setting the boundary to homogeneous dirichlet. - # Note that this will not work if flux is used in any other capacity but - # to take divergence. flux_bc_types.append((BCType.DIRICHLET, BCType.DIRICHLET)) - flux_bc_values.append((0.0, 0.0)) + elif (u.bc.types[axis][0] == BCType.DIRICHLET and + u.bc.types[axis][1] == BCType.DIRICHLET and + u.bc.bc_values[axis][0] == 0.0 and u.bc.bc_values[axis][1] == 0.0): + flux_bc_types.append((BCType.DIRICHLET, BCType.DIRICHLET)) else: - flux_bc_types_ax = [] - flux_bc_values_ax = [] - for i in range(2): # lower and upper boundary. - - # case 1: nonpourous boundary - if (u.bc.types[axis][i] == BCType.DIRICHLET and - u.bc.bc_values[axis][i] == 0.0): - flux_bc_types_ax.append(BCType.DIRICHLET) - flux_bc_values_ax.append(0.0) - - # case 2: zero flux boundary - elif (u.bc.types[axis][i] == BCType.NEUMANN and - c.bc.types[axis][i] == BCType.NEUMANN): - if not isinstance(c.bc, ConstantBoundaryConditions): - raise NotImplementedError( - 'Flux boundary condition is not implemented for scalar' + - f' with {c.bc}') - if not np.isclose(c.bc.bc_values[axis][i], 0.0): - raise NotImplementedError( - 'Flux boundary condition is not implemented for scalar' + - f' with {c.bc}') - flux_bc_types_ax.append(BCType.NEUMANN) - flux_bc_values_ax.append(0.0) - - # no other case is supported - else: - raise NotImplementedError( - f'Flux boundary condition is not implemented for {u.bc, c.bc}') - flux_bc_types.append(flux_bc_types_ax) - flux_bc_values.append(flux_bc_values_ax) - return ConstantBoundaryConditions(flux_bc_types, flux_bc_values) + raise NotImplementedError( + f'Flux boundary condition is not implemented for {u.bc, c.bc}') + return HomogeneousBoundaryConditions(flux_bc_types) + + +def new_periodic_boundary_conditions( + ndim: int, + bc_vals: Optional[Sequence[Tuple[float, float]]], + time_stamp: Optional[float], + bc_fn: Callable[...,Optional[float]], + +) -> ConstantBoundaryConditions: + """Returns BCs periodic for dimension 0 and Dirichlet for dimension 1. + + Args: + ndim: spatial dimension. + bc_fn: function describing the time dependent boundary condition + bc_vals: A tuple of lower and upper boundary values for each dimension. + If None, returns Homogeneous BC. For periodic dimensions the lower, upper + boundary values should be (None, None). + + Returns: + BoundaryCondition instance. + """ + bc_type = ((BCType.PERIODIC, BCType.PERIODIC), + (BCType.PERIODIC, BCType.PERIODIC)) + for _ in range(ndim - 2): + bc_type += ((BCType.PERIODIC, BCType.PERIODIC),) + + + + + return ConstantBoundaryConditions(values=bc_vals,time_stamp=time_stamp,types=bc_type,boundary_fn=bc_fn) diff --git a/jax_cfd/base/convolution_functions.py b/jax_cfd/base/convolution_functions.py new file mode 100644 index 0000000..4702a10 --- /dev/null +++ b/jax_cfd/base/convolution_functions.py @@ -0,0 +1,37 @@ +import jax +import jax.numpy as jnp + + +def delta_approx_logistjax(x,x0,w): + + return 1/(w*jnp.sqrt(2*jnp.pi))*jnp.exp(-0.5*((x-x0)/w)**2) + + + +def new_surf_fn(field,xp,yp,discrete_fn): + grid = field.grid + offset = field.offset + X,Y = grid.mesh(offset) + dx = grid.step[0] + dy = grid.step[1] + + def calc_force(xp,yp): + return jnp.sum(field.data*discrete_fn(xp,X,dx)*discrete_fn(yp,Y,dy)*dx*dy) + def foo(tree_arg): + xp,yp = tree_arg + return calc_force(xp,yp) + + def foo_pmap(tree_arg): + #print(tree_arg) + return jax.vmap(foo,in_axes=1)(tree_arg) + + divider=jax.device_count() + n = len(xp)//divider + mapped = [] + for i in range(divider): + # print(i) + mapped.append([xp[i*n:(i+1)*n],yp[i*n:(i+1)*n]]) + + U_deltas = jax.pmap(foo_pmap)(jnp.array(mapped)) + + return U_deltas.flatten() diff --git a/jax_cfd/base/diffusion.py b/jax_cfd/base/diffusion.py index eb28919..a52cbe6 100644 --- a/jax_cfd/base/diffusion.py +++ b/jax_cfd/base/diffusion.py @@ -15,15 +15,15 @@ # TODO(pnorgaard) Implement bicgstab for non-symmetric operators """Module for functionality related to diffusion.""" -from typing import Optional, Tuple +from typing import Optional -import jax.numpy as jnp import jax.scipy.sparse.linalg + from jax_cfd.base import array_utils -from jax_cfd.base import boundaries +from jax_ib.base import boundaries from jax_cfd.base import fast_diagonalization -from jax_cfd.base import finite_differences as fd -from jax_cfd.base import grids +from jax_ib.base import finite_differences as fd +from jax_ib.base import grids Array = grids.Array GridArray = grids.GridArray @@ -57,83 +57,6 @@ def stable_time_step(viscosity: float, grid: grids.Grid) -> float: return dx ** 2 / (viscosity * 2 ** ndim) -def _subtract_linear_part_dirichlet( - c_data: Array, - grid: grids.Grid, - axis: int, - offset: Tuple[float, float], - bc_values: Tuple[float, float], -) -> Array: - """Transforms c_data such that c_data satisfies dirichlet boundary. - - The function subtracts a linear function from c_data s.t. the returned - array has homogeneous dirichlet boundaries. Note that this assumes c_data has - constant dirichlet boundary values. - - Args: - c_data: right-hand-side of diffusion equation. - grid: grid object - axis: axis along which to impose boundary transformation - offset: offset of the right-hand-side - bc_values: boundary values along axis - - Returns: - transformed right-hand-side - """ - - def _update_rhs_along_axis(arr_1d, linear_part): - arr_1d = arr_1d - linear_part - return arr_1d - - lower_value, upper_value = bc_values - y = grid.mesh(offset)[axis][0] - one_d_grid = grids.Grid((grid.shape[axis],), domain=(grid.domain[axis],)) - y_boundary = boundaries.dirichlet_boundary_conditions(ndim=1) - y = y_boundary.trim_boundary(grids.GridArray(y, (offset[axis],), - one_d_grid)).data - domain_length = (grid.domain[axis][1] - grid.domain[axis][0]) - domain_start = grid.domain[axis][0] - linear_part = lower_value + (upper_value - lower_value) * ( - y - domain_start) / domain_length - c_data = jnp.apply_along_axis( - _update_rhs_along_axis, axis, c_data, linear_part) - return c_data - - -def _rhs_transform( - u: grids.GridArray, - bc: boundaries.BoundaryConditions, -) -> Array: - """Transforms the RHS of diffusion equation. - - In case of constant dirichlet boundary conditions for heat equation - the linear term is subtracted. See diffusion.solve_fast_diag. - - Args: - u: a GridArray that solves ∇²x = ∇²u for x. - bc: specifies boundary of u. - - Returns: - u' s.t. u = u' + w where u' has 0 dirichlet bc and w is linear. - """ - if not isinstance(bc, boundaries.ConstantBoundaryConditions): - raise NotImplementedError( - f'transformation cannot be done for this {bc}.') - u_data = u.data - for axis in range(u.grid.ndim): - for i, _ in enumerate(['lower', 'upper']): # lower and upper boundary - if bc.types[axis][i] == boundaries.BCType.DIRICHLET: - bc_values = [0., 0.] - bc_values[i] = bc.bc_values[axis][i] - u_data = _subtract_linear_part_dirichlet(u_data, u.grid, axis, u.offset, - bc_values) - elif bc.types[axis][i] == boundaries.BCType.NEUMANN: - if any(bc.bc_values[axis]): - raise NotImplementedError( - 'transformation is not implemented for inhomogeneous Neumann bc.') - return u_data - - def solve_cg(v: GridVariableVector, nu: float, dt: float, @@ -163,50 +86,30 @@ def cg(b: GridArray, x0: GridArray) -> GridArray: return tuple(grids.GridVariable(solve_component(u), u.bc) for u in v) -def solve_fast_diag( - v: GridVariableVector, - nu: float, - dt: float, - implementation: Optional[str] = None, -) -> GridVariableVector: +def solve_fast_diag(v: GridVariableVector, + nu: float, + dt: float, + implementation: Optional[str] = None) -> GridVariableVector: """Solve for diffusion using the fast diagonalization approach.""" # We reuse eigenvectors from the Laplacian and transform the eigenvalues # because this is better conditioned than directly diagonalizing 1 - ν Δt ∇² # when ν Δt is small. + if not boundaries.has_all_periodic_boundary_conditions(*v): + raise ValueError('solve_fast_diag() expects periodic BC') + grid = grids.consistent_grid(*v) + laplacians = list(map(array_utils.laplacian_matrix, grid.shape, grid.step)) + + # Transform the eigenvalues to implement (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) def func(x): dt_nu_x = (dt * nu) * x return dt_nu_x / (1 - dt_nu_x) - # Compute (1 - ν Δt ∇²)⁻¹ u as u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u, for less - # error when ν Δt is small. - # If dirichlet bc are supplied: only works for dirichlet bc that are linear - # functions on the boundary. Then u = u' + w where u' has 0 dirichlet bc and - # w is linear. Then u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u = u + - # (1 - ν Δt ∇²)⁻¹(ν Δt ∇²)u'. The function _rhs_transform subtracts - # the linear part s.t. fast_diagonalization solves - # u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u'. - v_diffused = list() - if boundaries.has_all_periodic_boundary_conditions(*v): - circulant = True - else: - circulant = False - # only matmul implementation supports non-circulant matrices - implementation = 'matmul' - for u in v: - laplacians = array_utils.laplacian_matrix_w_boundaries( - u.grid, u.offset, u.bc) - op = fast_diagonalization.transform( - func, - laplacians, - v[0].dtype, - hermitian=True, - circulant=circulant, - implementation=implementation) - u_interior = u.bc.trim_boundary(u.array) - u_interior_transformed = _rhs_transform(u_interior, u.bc) - u_dt_diffused = grids.GridArray( - op(u_interior_transformed), u_interior.offset, u_interior.grid) - u_diffused = u_interior + u_dt_diffused - u_diffused = u.bc.pad_and_impose_bc(u_diffused, offset_to_pad_to=u.offset) - v_diffused.append(u_diffused) - return tuple(v_diffused) + # Note: this assumes that each velocity field has the same shape and dtype. + op = fast_diagonalization.transform( + func, laplacians, v[0].dtype, + hermitian=True, circulant=True, implementation=implementation) + + # Compute (1 - ν Δt ∇²)⁻¹ u as u + (1 - ν Δt ∇²)⁻¹ (ν Δt ∇²) u, for less error + # when ν Δt is small. + return tuple(grids.GridVariable(u.array + grids.applied(op)(u.array), u.bc) + for u in v) diff --git a/jax_cfd/base/equations.py b/jax_cfd/base/equations.py index af0954e..7d14158 100644 --- a/jax_cfd/base/equations.py +++ b/jax_cfd/base/equations.py @@ -1,33 +1,22 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Examples of defining equations.""" + + import functools from typing import Callable, Optional import jax import jax.numpy as jnp -from jax_cfd.base import advection -from jax_cfd.base import diffusion -from jax_cfd.base import grids -from jax_cfd.base import pressure -from jax_cfd.base import time_stepping +from jax_ib.base import advection +from jax_ib.base import diffusion +from jax_ib.base import grids +from jax_ib.base import pressure +from jax_cfd.base import pressure as pressureCFD +from jax_ib.base import time_stepping +from jax_ib.base import boundaries +from jax_ib.base import finite_differences import tree_math - -# Specifying the full signatures of Callable would get somewhat onerous -# pylint: disable=g-bare-generic +from jax_ib.base import particle_class +from jax_cfd.base import equations as equationsCFD GridArray = grids.GridArray GridArrayVector = grids.GridArrayVector @@ -36,38 +25,14 @@ ConvectFn = Callable[[GridVariableVector], GridArrayVector] DiffuseFn = Callable[[GridVariable, float], GridArray] ForcingFn = Callable[[GridVariableVector], GridArrayVector] +BCFn = Callable[[particle_class.All_Variables, float], particle_class.All_Variables] +BCFn_new = Callable[[GridVariableVector, float], GridVariableVector] +IBMFn = Callable[[particle_class.All_Variables, float], GridVariableVector] +GradPFn = Callable[[GridVariable], GridArrayVector] +PosFn = Callable[[particle_class.All_Variables, float], particle_class.All_Variables] -def sum_fields(*args): - return jax.tree.map(lambda *a: sum(a), *args) - - -def stable_time_step( - max_velocity: float, - max_courant_number: float, - viscosity: float, - grid: grids.Grid, - implicit_diffusion: bool = False, -) -> float: - """Calculate a stable time step for Navier-Stokes.""" - dt = advection.stable_time_step(max_velocity, max_courant_number, grid) - if not implicit_diffusion: - diffusion_dt = diffusion.stable_time_step(viscosity, grid) - if diffusion_dt < dt: - raise ValueError(f'stable time step for diffusion is smaller than ' - f'the chosen timestep: {diffusion_dt} vs {dt}') - return dt - - -def dynamic_time_step(v: GridVariableVector, - max_courant_number: float, - viscosity: float, - grid: grids.Grid, - implicit_diffusion: bool = False) -> float: - """Pick a dynamic time-step for Navier-Stokes based on stable advection.""" - v_max = jnp.sqrt(jnp.max(sum(u.data ** 2 for u in v))) - return stable_time_step( # pytype: disable=wrong-arg-types # jax-types - v_max, max_courant_number, viscosity, grid, implicit_diffusion) +DragFn = Callable[[particle_class.All_Variables], particle_class.All_Variables] def _wrap_term_as_vector(fun, *, name): @@ -82,6 +47,7 @@ def navier_stokes_explicit_terms( convect: Optional[ConvectFn] = None, diffuse: DiffuseFn = diffusion.diffuse, forcing: Optional[ForcingFn] = None, + ) -> Callable[[GridVariableVector], GridVariableVector]: """Returns a function that performs a time step of Navier Stokes.""" del grid # unused @@ -107,6 +73,7 @@ def _explicit_terms(v): dv_dt += diffusion_(v, viscosity / density) if forcing is not None: dv_dt += forcing(v) / density + return dv_dt def explicit_terms_with_same_bcs(v): @@ -116,17 +83,129 @@ def explicit_terms_with_same_bcs(v): return explicit_terms_with_same_bcs -# TODO(shoyer): rename this to explicit_diffusion_navier_stokes -def semi_implicit_navier_stokes( + + + +def explicit_Reserve_BC( + ReserveBC: BCFn , + step_time: float, +) -> Callable[[GridVariableVector], GridVariableVector]: + + def Reserve_boundary(v, *args): + return ReserveBC(v, *args) + Reserve_bc_ = _wrap_term_as_vector(Reserve_boundary, name='Reserve_BC') + + @tree_math.wrap + # @functools.partial(jax.named_call, name='master_BC_fn') + def _Reserve_bc(v): + + return Reserve_bc_(v,step_time) + + return _Reserve_bc + +def explicit_update_BC( + updateBC: BCFn , + step_time: float, +) -> Callable[[GridVariableVector], GridVariableVector]: + + def Update_boundary(v, *args): + return updateBC(v, *args) + Update_bc_ = _wrap_term_as_vector(Update_boundary, name='Update_BC') + + @tree_math.wrap + # @functools.partial(jax.named_call, name='master_BC_fn') + def _Update_bc(v): + + return Update_bc_(v,step_time) + + return _Update_bc + + +def explicit_IBM_Force( + cal_IBM_force: IBMFn , + step_time: float, +) -> Callable[[GridVariableVector], GridVariableVector]: + + def IBM_FORCE(v, *args): + return cal_IBM_force(v, *args) + IBM_FORCE_ = _wrap_term_as_vector(IBM_FORCE, name='IBM_FORCE') + + @tree_math.wrap + # @functools.partial(jax.named_call, name='master_BC_fn') + def _IBM_FORCE(v): + + return IBM_FORCE_(v,step_time) + + return _IBM_FORCE + + + +def explicit_Update_position( + cal_Update_Position: PosFn , + step_time: float, +) -> Callable[[GridVariableVector], GridVariableVector]: + + def Update_Position(v, *args): + return cal_Update_Position(v, *args) + Update_Position_ = _wrap_term_as_vector(Update_Position, name='Update_Position') + + @tree_math.wrap + # @functools.partial(jax.named_call, name='master_BC_fn') + def _Update_Position(v): + + return Update_Position_(v,step_time) + + return _Update_Position + + +def explicit_Calc_Drag( + cal_Drag: DragFn , + step_time: float, +) -> Callable[[GridVariableVector], GridVariableVector]: + + def Calculate_Drag(v, *args): + return cal_Drag(v, *args) + Calculate_Drag_ = _wrap_term_as_vector(Calculate_Drag, name='Calculate_Drag') + + @tree_math.wrap + # @functools.partial(jax.named_call, name='master_BC_fn') + def _Calculate_Drag(v): + + return Calculate_Drag_(v,step_time) + + return _Calculate_Drag + +def explicit_Pressure_Gradient( + cal_Pressure_Grad: GradPFn, +) -> Callable[[GridVariableVector], GridVariableVector]: + + def Pressure_Grad(v): + return cal_Pressure_Grad(v) + Pressure_Grad_ = _wrap_term_as_vector(Pressure_Grad, name='Pressure_Grad') + + @tree_math.wrap + # @functools.partial(jax.named_call, name='master_BC_fn') + def _Pressure_Grad(v): + + return Pressure_Grad_(v) + + return _Pressure_Grad + +def semi_implicit_navier_stokes_timeBC( density: float, viscosity: float, dt: float, grid: grids.Grid, convect: Optional[ConvectFn] = None, diffuse: DiffuseFn = diffusion.diffuse, - pressure_solve: Callable = pressure.solve_fast_diag, + pressure_solve: Callable = pressureCFD.solve_fast_diag, forcing: Optional[ForcingFn] = None, - time_stepper: Callable = time_stepping.forward_euler, + time_stepper: Callable = time_stepping.forward_euler_updated, + IBM_forcing: IBMFn=None, + Updating_Position:PosFn=None , + Pressure_Grad:GradPFn=finite_differences.forward_difference, + Drag_fn:DragFn=None, + ) -> Callable[[GridVariableVector], GridVariableVector]: """Returns a function that performs a time step of Navier Stokes.""" @@ -139,57 +218,63 @@ def semi_implicit_navier_stokes( diffuse=diffuse, forcing=forcing) - pressure_projection = jax.named_call(pressure.projection, name='pressure') - + pressure_projection = jax.named_call(pressure.projection_and_update_pressure, name='pressure') + Reserve_BC = explicit_Reserve_BC(ReserveBC = boundaries.Reserve_BC,step_time = dt) + update_BC = explicit_update_BC(updateBC = boundaries.update_BC,step_time = dt) + IBM_force = explicit_IBM_Force(cal_IBM_force = IBM_forcing,step_time = dt) + Update_Position = explicit_Update_position(cal_Update_Position = Updating_Position,step_time = dt) + Pressure_Grad = explicit_Pressure_Gradient(cal_Pressure_Grad = Pressure_Grad) + Calculate_Drag = explicit_Calc_Drag(cal_Drag = Drag_fn,step_time = dt) + #jax.named_call(boundaries.update_BC, name='Update_BC') # TODO(jamieas): Consider a scheme where pressure calculations and # advection/diffusion are staggered in time. - ode = time_stepping.ExplicitNavierStokesODE( + ode = time_stepping.ExplicitNavierStokesODE_BCtime( explicit_terms, - lambda v: pressure_projection(v, pressure_solve) + lambda v: pressure_projection(v, pressure_solve), + update_BC, + Reserve_BC, + IBM_force, + Update_Position, + Pressure_Grad, + Calculate_Drag, ) step_fn = time_stepper(ode, dt) return step_fn -def implicit_diffusion_navier_stokes( +def semi_implicit_navier_stokes_penalty( density: float, viscosity: float, dt: float, grid: grids.Grid, convect: Optional[ConvectFn] = None, - diffusion_solve: Callable = diffusion.solve_fast_diag, - pressure_solve: Callable = pressure.solve_fast_diag, + diffuse: DiffuseFn = diffusion.diffuse, + pressure_solve: Callable = pressureCFD.solve_fast_diag, forcing: Optional[ForcingFn] = None, + time_stepper: Callable = time_stepping.forward_euler_penalty, ) -> Callable[[GridVariableVector], GridVariableVector]: """Returns a function that performs a time step of Navier Stokes.""" - del grid # unused - if convect is None: - def convect(v): # pylint: disable=function-redefined - return tuple( - advection.advect_van_leer_using_limiters(u, v, dt) for u in v) - convect = jax.named_call(convect, name='convection') - pressure_projection = jax.named_call(pressure.projection, name='pressure') - diffusion_solve = jax.named_call(diffusion_solve, name='diffusion') + explicit_terms = navier_stokes_explicit_terms( + density=density, + viscosity=viscosity, + dt=dt, + grid=grid, + convect=convect, + diffuse=diffuse, + forcing=forcing) - # TODO(shoyer): refactor to support optional higher-order time integators - @jax.named_call - def navier_stokes_step(v: GridVariableVector) -> GridVariableVector: - """Computes state at time `t + dt` using first order time integration.""" - convection = convect(v) - accelerations = [convection] - if forcing is not None: - # TODO(shoyer): include time in state? - f = forcing(v) - accelerations.append(tuple(f / density for f in f)) - dvdt = sum_fields(*accelerations) - # Update v by taking a time step - v = tuple( - grids.GridVariable(u.array + dudt * dt, u.bc) - for u, dudt in zip(v, dvdt)) - # Pressure projection to incompressible velocity field - v = pressure_projection(v, pressure_solve) - # Solve for implicit diffusion - v = diffusion_solve(v, viscosity, dt) - return v - return navier_stokes_step + pressure_projection = jax.named_call(pressure.projection_and_update_pressure, name='pressure') + Reserve_BC = explicit_Reserve_BC(ReserveBC = boundaries.Reserve_BC,step_time = dt) + update_BC = explicit_update_BC(updateBC = boundaries.update_BC,step_time = dt) + #jax.named_call(boundaries.update_BC, name='Update_BC') + # TODO(jamieas): Consider a scheme where pressure calculations and + # advection/diffusion are staggered in time. + ode = time_stepping.ExplicitNavierStokesODE_Penalty( + explicit_terms, + lambda v: pressure_projection(v, pressure_solve), + update_BC, + Reserve_BC, + ) + step_fn = time_stepper(ode, dt) + return step_fn diff --git a/jax_cfd/base/finite_differences.py b/jax_cfd/base/finite_differences.py index e84077a..350ffeb 100644 --- a/jax_cfd/base/finite_differences.py +++ b/jax_cfd/base/finite_differences.py @@ -32,9 +32,11 @@ import typing from typing import Optional, Sequence, Tuple -from jax_cfd.base import grids -from jax_cfd.base import interpolation +from jax_ib.base import grids +from jax_ib.base import interpolation import numpy as np +import jax +import jax.numpy as jnp GridArray = grids.GridArray GridVariable = grids.GridVariable @@ -76,7 +78,7 @@ def central_difference(u, axis=None): if axis is None: axis = range(u.grid.ndim) if not isinstance(axis, int): - return tuple(central_difference(u, a) for a in axis) # pytype: disable=wrong-arg-types # always-use-return-annotations + return tuple(central_difference(u, a) for a in axis) diff = stencil_sum(u.shift(+1, axis), -u.shift(-1, axis)) return diff / (2 * u.grid.step[axis]) @@ -97,7 +99,7 @@ def backward_difference(u, axis=None): if axis is None: axis = range(u.grid.ndim) if not isinstance(axis, int): - return tuple(backward_difference(u, a) for a in axis) # pytype: disable=wrong-arg-types # always-use-return-annotations + return tuple(backward_difference(u, a) for a in axis) diff = stencil_sum(u.array, -u.shift(-1, axis)) return diff / u.grid.step[axis] @@ -119,15 +121,16 @@ def forward_difference(u, axis=None): if axis is None: axis = range(u.grid.ndim) if not isinstance(axis, int): - return tuple(forward_difference(u, a) for a in axis) # pytype: disable=wrong-arg-types # always-use-return-annotations + return tuple(forward_difference(u, a) for a in axis) diff = stencil_sum(u.shift(+1, axis), -u.array) return diff / u.grid.step[axis] def laplacian(u: GridVariable) -> GridArray: """Approximates the Laplacian of `u`.""" - scales = np.square(1 / np.array(u.grid.step, dtype=u.dtype)) - result = -2 * u.array * np.sum(scales) + scales = np.square(1 / np.array(u.grid.step, dtype=u.dtype)) + #scales = jnp.square(1 / jnp.array(u.grid.step, dtype=u.dtype)) #return to np instead of jnp + result = -2 * u.array * jnp.sum(scales) for axis in range(u.grid.ndim): result += stencil_sum(u.shift(-1, axis), u.shift(+1, axis)) * scales[axis] return result @@ -166,7 +169,7 @@ def gradient_tensor(v: Sequence[GridVariable]) -> GridArrayTensor: def gradient_tensor(v): """Approximates the cell-centered gradient of `v`.""" if not isinstance(v, GridVariable): - return GridArrayTensor(np.stack([gradient_tensor(u) for u in v], axis=-1)) # pytype: disable=wrong-arg-types # always-use-return-annotations + return GridArrayTensor(np.stack([gradient_tensor(u) for u in v], axis=-1)) grad = [] for axis in range(v.grid.ndim): offset = v.offset[axis] diff --git a/jax_cfd/base/grids.py b/jax_cfd/base/grids.py index 3c87f38..fd0f032 100644 --- a/jax_cfd/base/grids.py +++ b/jax_cfd/base/grids.py @@ -18,9 +18,8 @@ import numbers import operator from typing import Any, Callable, Optional, Sequence, Tuple, Union - -import jax from jax import core +import jax import jax.numpy as jnp from jax.tree_util import register_pytree_node_class import numpy as np @@ -37,6 +36,74 @@ PyTree = Any +@register_pytree_node_class +@dataclasses.dataclass +class BCArray(np.lib.mixins.NDArrayOperatorsMixin): + """Data with an alignment offset and an associated grid. + + Offset values in the range [0, 1] fall within a single grid cell. + + Examples: + offset=(0, 0) means that each point is at the bottom-left corner. + offset=(0.5, 0.5) is at the grid center. + offset=(1, 0.5) is centered on the right-side edge. + + Attributes: + data: array values. + offset: alignment location of the data with respect to the grid. + grid: the Grid associated with the array data. + dtype: type of the array data. + shape: lengths of the array dimensions. + """ + # Don't (yet) enforce any explicit consistency requirements between data.ndim + # and len(offset), e.g., so we can feel to add extra time/batch/channel + # dimensions. But in most cases they should probably match. + # Also don't enforce explicit consistency between data.shape and grid.shape, + # but similarly they should probably match. + data: Array + + + def tree_flatten(self): + """Returns flattening recipe for BCArray JAX pytree.""" + children = (self.data,) + aux_data = None + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Returns unflattening recipe for BCArray JAX pytree.""" + return cls(*children) + + @property + def dtype(self): + return self.data.dtype + + @property + def shape(self) -> Tuple[int, ...]: + return self.data.shape + + _HANDLED_TYPES = (numbers.Number, np.ndarray, jax.Array, + core.ShapedArray, jax.core.Tracer) + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + """Define arithmetic on BCArray using NumPy's mixin.""" + for x in inputs: + if not isinstance(x, self._HANDLED_TYPES + (BCArray,)): + return NotImplemented + if method != '__call__': + return NotImplemented + try: + # get the corresponding jax.np function to the NumPy ufunc + func = getattr(jnp, ufunc.__name__) + except AttributeError: + return NotImplemented + arrays = [x.data if isinstance(x, BCArray) else x for x in inputs] + result = func(*arrays) + if isinstance(result, tuple): + return tuple(BCArray(r) for r in result) + else: + return BCArray(result) + @register_pytree_node_class @dataclasses.dataclass class GridArray(np.lib.mixins.NDArrayOperatorsMixin): @@ -84,7 +151,8 @@ def dtype(self): def shape(self) -> Tuple[int, ...]: return self.data.shape - _HANDLED_TYPES = (numbers.Number, np.ndarray, jax.Array, core.ShapedArray) + _HANDLED_TYPES = (numbers.Number, np.ndarray, jax.Array, + core.ShapedArray, jax.core.Tracer) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): """Define arithmetic on GridArrays using NumPy's mixin.""" @@ -102,6 +170,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): result = func(*arrays) offset = consistent_offset(*[x for x in inputs if isinstance(x, GridArray)]) grid = consistent_grid(*[x for x in inputs if isinstance(x, GridArray)]) + #grid = inputs.grid#consistent_grid(*[x for x in inputs]) if isinstance(result, tuple): return tuple(GridArray(r, offset, grid) for r in result) else: @@ -137,7 +206,7 @@ def __new__(cls, arrays): ) -@dataclasses.dataclass(init=False, frozen=True) +@dataclasses.dataclass(init=False, frozen=False) class BoundaryConditions: """Base class for boundary conditions on a PDE variable. @@ -147,12 +216,14 @@ class BoundaryConditions: """ types: Tuple[Tuple[str, str], ...] + + + def shift( self, u: GridArray, offset: int, axis: int, - mode: Optional[str] = 'extend', ) -> GridArray: """Shift an GridArray by `offset`. @@ -160,8 +231,6 @@ def shift( u: an `GridArray` object. offset: positive or negative integer offset to shift. axis: axis to shift along. - mode: specifies how to extend past the boundary/ghost cells. - Valid options contained in boundaries.Padding. Returns: A copy of `u`, shifted by `offset`. The returned `GridArray` has offset @@ -193,7 +262,6 @@ def pad( u: GridArray, width: int, axis: int, - mode: Optional[str] = 'extend', ) -> GridArray: """Returns Arrays padded according to boundary condition. @@ -202,8 +270,6 @@ def pad( width: number of elements to pad along axis. Use negative value for lower boundary or positive value for upper boundary. axis: axis to pad along. - mode: specifies how to extend past the boundary/ghost cells. - Valid options contained in boundaries.Padding. Returns: A GridArray that is elongated along axis with padded values. @@ -299,14 +365,14 @@ def __post_init__(self): def tree_flatten(self): """Returns flattening recipe for GridVariable JAX pytree.""" - children = (self.array,) - aux_data = (self.bc,) + children = (self.array,self.bc) + aux_data = None return children, aux_data @classmethod def tree_unflatten(cls, aux_data, children): """Returns unflattening recipe for GridVariable JAX pytree.""" - return cls(*children, *aux_data) + return cls(*children) @property def dtype(self): @@ -332,21 +398,18 @@ def shift( self, offset: int, axis: int, - mode: Optional[str] = 'extend', ) -> GridArray: """Shift this GridVariable by `offset`. Args: offset: positive or negative integer offset to shift. axis: axis to shift along. - mode: specifies how to extend past the boundary/ghost cells. - Valid options contained in boundaries.Padding. Returns: A copy of the encapsulated GridArray, shifted by `offset`. The returned GridArray has offset `u.offset + offset`. """ - return self.bc.shift(self.array, offset, axis, mode) + return self.bc.shift(self.array, offset, axis) def _interior_grid(self) -> Grid: """Returns only the interior grid points.""" @@ -480,13 +543,14 @@ def consistent_grid(*arrays: Union[GridArray, GridVariable]) -> Grid: raise InconsistentGridError(f'arrays do not have a unique grid: {grids}') grid, = grids return grid + #return arrays[0].grid class InconsistentBoundaryConditionsError(Exception): """Raised for cases of inconsistent bc between GridVariables.""" -def unique_boundary_conditions(*arrays: GridVariable) -> BoundaryConditions: +def consistent_boundary_conditions(*arrays: GridVariable) -> BoundaryConditions: """Returns the unique BCs, or raises InconsistentBoundaryConditionsError.""" bcs = {array.bc for array in arrays} if len(bcs) != 1: @@ -524,7 +588,7 @@ def __init__( object.__setattr__(self, 'shape', shape) if step is not None and domain is not None: - raise TypeError('cannot provide both step and domain') + raise TypeError('MODIFIED cannot provide both step and domain') elif domain is not None: if isinstance(domain, (int, float)): domain = ((0, domain),) * len(shape) @@ -537,6 +601,7 @@ def __init__( raise ValueError( f'domain is not sequence of pairs of numbers: {domain}') domain = tuple((float(lower), float(upper)) for lower, upper in domain) + else: if step is None: step = 1 @@ -579,7 +644,7 @@ def stagger(self, v: Tuple[Array, ...]) -> Tuple[GridArray, ...]: def center(self, v: PyTree) -> PyTree: """Places all arrays in the pytree `v` at the `Grid`'s cell center.""" offset = self.cell_center - return jax.tree.map(lambda u: GridArray(u, offset, self), v) + return jax.tree_map(lambda u: GridArray(u, offset, self), v) def axes(self, offset: Optional[Sequence[float]] = None) -> Tuple[Array, ...]: """Returns a tuple of arrays containing the grid points along each axis. diff --git a/jax_cfd/base/interpolation.py b/jax_cfd/base/interpolation.py index 8ef5788..12cadb0 100644 --- a/jax_cfd/base/interpolation.py +++ b/jax_cfd/base/interpolation.py @@ -18,10 +18,11 @@ import jax import jax.numpy as jnp -from jax_cfd.base import boundaries -from jax_cfd.base import grids +from jax_ib.base import boundaries +from jax_ib.base import grids import numpy as np + Array = Union[np.ndarray, jax.Array] GridArray = grids.GridArray GridArrayVector = grids.GridArrayVector diff --git a/jax_cfd/base/kinematics.py b/jax_cfd/base/kinematics.py new file mode 100644 index 0000000..ed413a0 --- /dev/null +++ b/jax_cfd/base/kinematics.py @@ -0,0 +1,108 @@ +import jax +import jax.numpy as jnp + + +def displacement(parameters,t): + A0,f = list(*parameters) + return jnp.array([A0/2*jnp.cos(2*jnp.pi*f*t),0.]) + + +def rotation(parameters,t): + alpha0,beta,f,phi = list(*parameters) + return alpha0 + beta*jnp.sin(2*jnp.pi*f*t+phi) + +def Displacement_Foil_Fourier_Dotted_Mutliple(parameters,t): + + alpha0=jnp.array(list(list(zip(*parameters))[0])) + f =jnp.array(list(list(zip(*parameters))[1])) + phi = jnp.array(list(list(zip(*parameters))[2])) + alpha = jnp.array(list(list(zip(*parameters))[3])) + beta = jnp.array(list(list(zip(*parameters))[4])) + p = jnp.array(list(list(zip(*parameters))[5])) + + size_parameters = alpha.shape[1] + N_particles =len(alpha) + + ## Create an array of the size (nparticles, nparameters) + frequencies = jnp.array([jnp.arange(1,size_parameters+1)]*N_particles) + + ## multiply and add arrays + + inside_function =jnp.add(2*jnp.pi*t*frequencies*f.reshape(N_particles,1),phi.reshape(N_particles,1)) + + alpha_1 = (alpha*jnp.sin(inside_function)).sum(axis=1) + + alpha_1 += p*(beta*jnp.cos(inside_function)).sum(axis=1) + + + return jnp.array([-alpha0*t,alpha_1]) + + +def rotation_Foil_Fourier_Dotted_Mutliple(parameters,t): + #alpha0,f,phi,alpha,beta,p = parameters + alpha0=jnp.array(list(list(zip(*parameters))[0])) + f =jnp.array(list(list(zip(*parameters))[1])) + phi = jnp.array(list(list(zip(*parameters))[2])) + alpha = jnp.array(list(list(zip(*parameters))[3])) + beta = jnp.array(list(list(zip(*parameters))[4])) + p = jnp.array(list(list(zip(*parameters))[5])) + + size_parameters = alpha.shape[1] + N_particles =len(alpha) + + ## Create an array of the size (nparticles, nparameters) + frequencies = jnp.array([jnp.arange(1,size_parameters+1)]*N_particles) + + ## multiply and add arrays + + inside_function =jnp.add(2*jnp.pi*t*frequencies*f.reshape(N_particles,1),phi.reshape(N_particles,1)) + + alpha_1 = (alpha*jnp.sin(inside_function)).sum(axis=1) + + alpha_1 += p*(beta*jnp.cos(inside_function)).sum(axis=1) + + #if N_particles>1: + return alpha0*t + alpha_1 + +def rotation_Foil_Fourier_Dotted_Mutliple_NORMALIZED(parameters,t): + #alpha0,f,phi,alpha,beta,p = parameters + alpha0=jnp.array(list(list(zip(*parameters))[0])) + f =jnp.array(list(list(zip(*parameters))[1])) + phi = jnp.array(list(list(zip(*parameters))[2])) + alpha = jnp.array(list(list(zip(*parameters))[3])) + beta = jnp.array(list(list(zip(*parameters))[4])) + theta_av = jnp.array(list(list(zip(*parameters))[5])) + p = jnp.array(list(list(zip(*parameters))[6])) + + size_parameters = alpha.shape[1] + N_particles =len(alpha) + + ## Create an array of the size (nparticles, nparameters) + frequencies = jnp.array([jnp.arange(1,size_parameters+1)]*N_particles) + + ## multiply and add arrays + + inside_function =jnp.add(2*jnp.pi*t*frequencies*f.reshape(N_particles,1),phi.reshape(N_particles,1)) + + alpha_1 = (alpha*jnp.sin(inside_function)).sum(axis=1) + + alpha_1 += p*(beta*jnp.cos(inside_function)).sum(axis=1) + + inside_function2 =jnp.add(2*jnp.pi*frequencies,phi.reshape(N_particles,1)) + + alpha_2 = (alpha*jnp.sin(inside_function2)).sum(axis=1) + + alpha_2 += p*(beta*jnp.cos(inside_function2)).sum(axis=1) + + inside_function3 =jnp.add(2*jnp.pi*frequencies*0.0,phi.reshape(N_particles,1)) + + alpha_3 = (alpha*jnp.sin(inside_function3)).sum(axis=1) + + alpha_3 += p*(beta*jnp.cos(inside_function3)).sum(axis=1) + + + #if N_particles>1: + return theta_av*(alpha0*t + alpha_1)/(alpha0 + alpha_2-alpha_3) + #return (alpha0 + alpha_2) + #else: + # return (alpha0*t + alpha_1)[0] diff --git a/jax_cfd/base/particle_class.py b/jax_cfd/base/particle_class.py new file mode 100644 index 0000000..66af17d --- /dev/null +++ b/jax_cfd/base/particle_class.py @@ -0,0 +1,195 @@ +import dataclasses +import numbers +import operator +from typing import Any, Callable, Optional, Sequence, Tuple, Union + +import jax +import jax.numpy as jnp +from jax.tree_util import register_pytree_node_class +import numpy as np +from jax_ib.base import grids + + +Array = Union[np.ndarray, jax.Array] +IntOrSequence = Union[int, Sequence[int]] + +# There is currently no good way to indicate a jax "pytree" with arrays at its +# leaves. See https://jax.readthedocs.io/en/latest/jax.tree_util.html for more +# information about PyTrees and https://github.com/google/jax/issues/3340 for +# discussion of this issue. +PyTree = Any +@dataclasses.dataclass(init=False, frozen=True) +class Grid1d: + """Describes the size and shape for an Arakawa C-Grid. + + See https://en.wikipedia.org/wiki/Arakawa_grids. + + This class describes domains that can be written as an outer-product of 1D + grids. Along each dimension `i`: + - `shape[i]` gives the whole number of grid cells on a single device. + - `step[i]` is the width of each grid cell. + - `(lower, upper) = domain[i]` gives the locations of lower and upper + boundaries. The identity `upper - lower = step[i] * shape[i]` is enforced. + """ + shape: Tuple[int, ...] + step: Tuple[float, ...] + domain: Tuple[Tuple[float, float], ...] + + def __init__( + self, + shape: Sequence[int], + step: Optional[Union[float, Sequence[float]]] = None, + domain: Optional[Union[float, Sequence[Tuple[float, float]]]] = None, + ): + """Construct a grid object.""" + shape = shape + object.__setattr__(self, 'shape', shape) + + + + object.__setattr__(self, 'domain', domain) + + step = (domain[1] - domain[0]) / (shape-1) + object.__setattr__(self, 'step', step) + + @property + def ndim(self) -> int: + """Returns the number of dimensions of this grid.""" + return 1 + + @property + def cell_center(self) -> Tuple[float, ...]: + """Offset at the center of each grid cell.""" + return self.ndim * (0.5,) + + + + def axes(self, offset: Optional[Sequence[float]] = None) -> Tuple[Array, ...]: + """Returns a tuple of arrays containing the grid points along each axis. + + Args: + offset: an optional sequence of length `ndim`. The grid will be shifted by + `offset * self.step`. + + Returns: + An tuple of `self.ndim` arrays. The jth return value has shape + `[self.shape[j]]`. + """ + if offset is None: + offset = self.cell_center + if len(offset) != self.ndim: + raise ValueError(f'unexpected offset length: {len(offset)} vs ' + f'{self.ndim}') + + return self.domain[0] + jnp.arange(self.shape)*self.step + + + + def mesh(self, offset: Optional[Sequence[float]] = None) -> Tuple[Array, ...]: + """Returns an tuple of arrays containing positions in each grid cell. + + Args: + offset: an optional sequence of length `ndim`. The grid will be shifted by + `offset * self.step`. + + Returns: + An tuple of `self.ndim` arrays, each of shape `self.shape`. In 3 + dimensions, entry `self.mesh[n][i, j, k]` is the location of point + `i, j, k` in dimension `n`. + """ + + return self.axes(offset) + + + + +@register_pytree_node_class +@dataclasses.dataclass +class particle: + particle_center: Sequence[Any] + geometry_param: Sequence[Any] + displacement_param: Sequence[Any] + rotation_param: Sequence[Any] + Grid: Grid1d + shape: Callable + Displacement_EQ: Callable + Rotation_EQ: Callable + + + + + + def tree_flatten(self): + """Returns flattening recipe for GridVariable JAX pytree.""" + children = (self.particle_center,self.geometry_param,self.displacement_param,self.rotation_param,) + + aux_data = (self.Grid,self.shape,self.Displacement_EQ,self.Rotation_EQ,) + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Returns unflattening recipe for GridVariable JAX pytree.""" + return cls(*children,*aux_data) + + def generate_grid(self): + + return self.Grid.mesh() + + def calc_Rtheta(self): + return self.shape(self.geometry_param,self.Grid) + + + + +@register_pytree_node_class +@dataclasses.dataclass +class All_Variables: + particles: Sequence[particle,] + velocity: grids.GridVariableVector + pressure: grids.GridVariable + Drag:Sequence[Any] + Step_count:int + MD_var:Any + def tree_flatten(self): + """Returns flattening recipe for GridVariable JAX pytree.""" + children = (self.particles,self.velocity,self.pressure,self.Drag,self.Step_count,self.MD_var,) + + aux_data = None + return children, aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Returns unflattening recipe for GridVariable JAX pytree.""" + return cls(*children) + + + + +@register_pytree_node_class +@dataclasses.dataclass +class particle_lista: # SEQUENCE OF VARIABLES MATTER ! + particles: Sequence[particle,] + + + def generate_grid(self): + + return np.stack([grid.mesh() for grid in self.Grid]) + + def calc_Rtheta(self): + return self.shape(self.geometry_param,self.Grid) + + def tree_flatten(self): + """Returns flattening recipe for GridVariable JAX pytree.""" + children = (*self.particles,) + aux_data = None + return children,aux_data + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Returns unflattening recipe for GridVariable JAX pytree.""" + return cls(*children) + + + + + diff --git a/jax_cfd/base/particle_motion.py b/jax_cfd/base/particle_motion.py new file mode 100644 index 0000000..a80e236 --- /dev/null +++ b/jax_cfd/base/particle_motion.py @@ -0,0 +1,66 @@ +from jax_ib.base import particle_class as pc +import jax +import jax.numpy as jnp + + +def Update_particle_position_Multiple_and_MD_Step(step_fn,all_variables,dt): + particles = all_variables.particles + Drag = all_variables.Drag + velocity = all_variables.velocity + current_t =velocity[0].bc.time_stamp + particle_centers = particles.particle_center + Displacement_EQ = particles.Displacement_EQ + displacement_param = particles.displacement_param + New_eq = lambda t:Displacement_EQ(displacement_param,t) + dx_dt = jax.jacrev(New_eq) + + + #MD_var = step_fn(MD_var) + + U0 =dx_dt(current_t) + #print(U0) + Newparticle_center = jnp.array([particle_centers[:,0]+dt*U0[0],particle_centers[:,1]+dt*U0[1]]).T + #print(Newparticle_center) + mygrids = particles.Grid + param_geometry = particles.geometry_param + shape_fn = particles.shape + pressure = all_variables.pressure + Step_count = all_variables.Step_count + 1 + rotation_param = particles.rotation_param + + MD_var = step_fn(all_variables) + + New_particles = pc.particle(Newparticle_center,param_geometry,displacement_param,rotation_param,mygrids,shape_fn,Displacement_EQ,particles.Rotation_EQ) + + return pc.All_Variables(New_particles,velocity,pressure,Drag,Step_count,MD_var) + + +def Update_particle_position_Multiple(all_variables,dt): + particles = all_variables.particles + Drag = all_variables.Drag + velocity = all_variables.velocity + current_t =velocity[0].bc.time_stamp + particle_centers = particles.particle_center + Displacement_EQ = particles.Displacement_EQ + displacement_param = particles.displacement_param + New_eq = lambda t:Displacement_EQ(displacement_param,t) + dx_dt = jax.jacrev(New_eq) + + + + U0 =dx_dt(current_t) + #print(U0) + Newparticle_center = jnp.array([particle_centers[:,0]+dt*U0[0],particle_centers[:,1]+dt*U0[1]]).T + #print(Newparticle_center) + mygrids = particles.Grid + param_geometry = particles.geometry_param + shape_fn = particles.shape + pressure = all_variables.pressure + Step_count = all_variables.Step_count + 1 + rotation_param = particles.rotation_param + + MD_var = all_variables.MD_var + + New_particles = pc.particle(Newparticle_center,param_geometry,displacement_param,rotation_param,mygrids,shape_fn,Displacement_EQ,particles.Rotation_EQ) + + return pc.All_Variables(New_particles,velocity,pressure,Drag,Step_count,MD_var) diff --git a/jax_cfd/base/pressure.py b/jax_cfd/base/pressure.py index dca8160..d748bf8 100644 --- a/jax_cfd/base/pressure.py +++ b/jax_cfd/base/pressure.py @@ -1,29 +1,16 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Functions for computing and applying pressure.""" -from typing import Callable, Optional - -import jax.numpy as jnp -import jax.scipy.sparse.linalg -from jax_cfd.base import array_utils -from jax_cfd.base import boundaries +from typing import Callable, Optional +import scipy.linalg +import numpy as np +from jax_ib.base import array_utils from jax_cfd.base import fast_diagonalization -from jax_cfd.base import finite_differences as fd -from jax_cfd.base import grids +import jax.numpy as jnp +from jax_cfd.base import pressure +from jax_ib.base import grids +from jax_ib.base import boundaries +from jax_ib.base import finite_differences as fd +from jax_ib.base import particle_class Array = grids.Array GridArray = grids.GridArray @@ -32,11 +19,15 @@ GridVariableVector = grids.GridVariableVector BoundaryConditions = grids.BoundaryConditions -# Specifying the full signatures of Callable would get somewhat onerous -# pylint: disable=g-bare-generic - -# TODO(pnorgaard) Implement bicgstab for non-symmetric operators +def laplacian_matrix_neumann(size: int, step: float) -> np.ndarray: + """Create 1D Laplacian operator matrix, with homogeneous Neumann BC.""" + column = np.zeros(size) + column[0] = -2 / step ** 2 + column[1] = 1 / step ** 2 + matrix = scipy.linalg.toeplitz(column) + matrix[0, 0] = matrix[-1, -1] = -1 / step**2 + return matrix def _rhs_transform( @@ -62,123 +53,107 @@ def _rhs_transform( # functions. We substact the mean to ensure consistency. u_data = u_data - jnp.mean(u_data) return u_data + +def projection_and_update_pressure( + All_variables: particle_class.All_Variables, + solve: Callable = pressure.solve_fast_diag, +) -> GridVariableVector: + """Apply pressure projection to make a velocity field divergence free.""" + v = All_variables.velocity + old_pressure = All_variables.pressure + particles = All_variables.particles + Drag = All_variables.Drag + Step_count = All_variables.Step_count + MD_var = All_variables.MD_var + grid = grids.consistent_grid(*v) + pressure_bc = boundaries.get_pressure_bc_from_velocity(v) -def solve_cg( - v: GridVariableVector, - q0: GridVariable, - pressure_bc: Optional[boundaries.ConstantBoundaryConditions] = None, - rtol: float = 1e-6, - atol: float = 1e-6, - maxiter: Optional[int] = None) -> GridArray: - """Conjugate gradient solve for the pressure such that continuity is enforced. - - Returns a pressure correction `q` such that `div(v - grad(q)) == 0`. - - The relationship between `q` and our actual pressure estimate is given by - `p = q * density / dt`. - - Args: - v: the velocity field. - q0: an initial value, or "guess" for the pressure correction. A common - choice is the correction from the previous time step. Also specifies the - boundary conditions on `q`. - pressure_bc: the boundary condition to assign to pressure. If None, - boundary condition is infered from velocity. - rtol: relative tolerance for convergence. - atol: absolute tolerance for convergence. - maxiter: optional int, the maximum number of iterations to perform. - - Returns: - A pressure correction `q` such that `div(v - grad(q))` is zero. - """ - # TODO(jamieas): add functionality for non-uniform density. - rhs = fd.divergence(v) - - if pressure_bc is None: - pressure_bc = boundaries.get_pressure_bc_from_velocity(v) + q0 = grids.GridArray(jnp.zeros(grid.shape), grid.cell_center, grid) + q0 = grids.GridVariable(q0, pressure_bc) - def laplacian_with_bcs(array: GridArray) -> GridArray: - variable = pressure_bc.impose_bc(array) - return fd.laplacian(variable) + qsol = solve(v, q0) + q = grids.GridVariable(qsol, pressure_bc) + + New_pressure_Array = grids.GridArray(qsol.data + old_pressure.data,qsol.offset,qsol.grid) + New_pressure = grids.GridVariable(New_pressure_Array,pressure_bc) - q, _ = jax.scipy.sparse.linalg.cg( - laplacian_with_bcs, - rhs, - x0=q0.array, - tol=rtol, - atol=atol, - maxiter=maxiter) - return q + q_grad = fd.forward_difference(q) + if boundaries.has_all_periodic_boundary_conditions(*v): + v_projected = tuple( + grids.GridVariable(u.array - q_g, u.bc) for u, q_g in zip(v, q_grad)) + new_variable = particle_class.All_Variables(particles,v_projected,New_pressure,Drag,Step_count,MD_var) + else: + v_projected = tuple( + grids.GridVariable(u.array - q_g, u.bc).impose_bc() + for u, q_g in zip(v, q_grad)) + new_variable = particle_class.All_Variables(particles,v_projected,New_pressure,Drag,Step_count,MD_var) + return new_variable def solve_fast_diag( v: GridVariableVector, - q0: Optional[grids.GridArray] = None, - pressure_bc: Optional[boundaries.ConstantBoundaryConditions] = None, - implementation: Optional[str] = None, -) -> grids.GridArray: - """Solve for pressure using the fast diagonalization approach. - - To support backward compatibility, if the pressure_bc are not provided and - velocity has all periodic boundaries, pressure_bc are assigned to be periodic. - - Args: - v: a tuple of velocity values for each direction. - q0: the starting guess for the pressure. - pressure_bc: the boundary condition to assign to pressure. If None, - boundary condition is infered from velocity. - implementation: how to implement fast diagonalization. - For non-periodic BCs will automatically be matmul. - - - Returns: - A solution to the PPE equation. - """ + q0: Optional[GridVariable] = None, + implementation: Optional[str] = None) -> GridArray: + """Solve for pressure using the fast diagonalization approach.""" del q0 # unused - if pressure_bc is None: - pressure_bc = boundaries.get_pressure_bc_from_velocity(v) - if boundaries.has_all_periodic_boundary_conditions(*v): - circulant = True - else: - circulant = False - # only matmul implementation supports non-circulant matrices - implementation = 'matmul' + if not boundaries.has_all_periodic_boundary_conditions(*v): + raise ValueError('solve_fast_diag() expects periodic velocity BC') + grid = grids.consistent_grid(*v) rhs = fd.divergence(v) - laplacians = array_utils.laplacian_matrix_w_boundaries( - rhs.grid, rhs.offset, pressure_bc) - rhs_transformed = _rhs_transform(rhs, pressure_bc) + laplacians = list(map(array_utils.laplacian_matrix, grid.shape, grid.step)) pinv = fast_diagonalization.pseudoinverse( - laplacians, - rhs_transformed.dtype, - hermitian=True, - circulant=circulant, - implementation=implementation) - return grids.GridArray(pinv(rhs_transformed), rhs.offset, rhs.grid) + laplacians, rhs.dtype, + hermitian=True, circulant=True, implementation=implementation) + return grids.applied(pinv)(rhs) -def solve_fast_diag_channel_flow( +def solve_fast_diag_moving_wall( v: GridVariableVector, - q0: Optional[grids.GridArray] = None, - pressure_bc: Optional[boundaries.ConstantBoundaryConditions] = None, -) -> grids.GridArray: - """Applies solve_fast_diag for channel flow. - - Args: - v: a tuple of velocity values for each direction. - q0: the starting guess for the pressure. - pressure_bc: the boundary condition to assign to pressure. If None, - boundary condition is infered from velocity. + q0: Optional[GridVariable] = None, + implementation: Optional[str] = 'matmul') -> GridArray: + """Solve for channel flow pressure using fast diagonalization.""" + del q0 # unused + ndim = len(v) - Returns: - A solutiion to the PPE equation. - """ - if pressure_bc is None: - pressure_bc = boundaries.get_pressure_bc_from_velocity(v) - return solve_fast_diag(v, q0, pressure_bc, implementation='matmul') + grid = grids.consistent_grid(*v) + rhs = fd.divergence(v) + laplacians = [ + array_utils.laplacian_matrix(grid.shape[0], grid.step[0]), + array_utils.laplacian_matrix_neumann(grid.shape[1], grid.step[1]), + ] + for d in range(2, ndim): + laplacians += [array_utils.laplacian_matrix(grid.shape[d], grid.step[d])] + pinv = fast_diagonalization.pseudoinverse( + laplacians, rhs.dtype, + hermitian=True, circulant=False, implementation=implementation) + return grids.applied(pinv)(rhs) + + + +def solve_fast_diag_Far_Field( + v: GridVariableVector, + q0: Optional[GridVariable] = None, + implementation: Optional[str] = None) -> GridArray: + """Solve for pressure using the fast diagonalization approach.""" + del q0 # unused + grid = grids.consistent_grid(*v) + rhs = fd.divergence(v) + pressure_bc = boundaries.get_pressure_bc_from_velocity(v) + rhs_transformed = _rhs_transform(rhs, pressure_bc) + #laplacians = [ + # laplacian_matrix_neumann(grid.shape[0], grid.step[0]), + # laplacian_matrix_neumann(grid.shape[1], grid.step[1]), + #] + laplacians = array_utils.laplacian_matrix_w_boundaries( + rhs.grid, rhs.offset, pressure_bc) + pinv = fast_diagonalization.pseudoinverse( + laplacians, rhs_transformed.dtype, + hermitian=True, circulant=False, implementation='matmul') + return grids.applied(pinv)(rhs) -def projection( +def calc_P( v: GridVariableVector, solve: Callable = solve_fast_diag, ) -> GridVariableVector: @@ -187,12 +162,9 @@ def projection( pressure_bc = boundaries.get_pressure_bc_from_velocity(v) q0 = grids.GridArray(jnp.zeros(grid.shape), grid.cell_center, grid) - q0 = pressure_bc.impose_bc(q0) + q0 = grids.GridVariable(q0, pressure_bc) - q = solve(v, q0, pressure_bc) - q = pressure_bc.impose_bc(q) - q_grad = fd.forward_difference(q) - v_projected = tuple( - u.bc.impose_bc(u.array - q_g) for u, q_g in zip(v, q_grad)) + q = solve(v, q0) + q = grids.GridVariable(q, pressure_bc) - return v_projected + return q diff --git a/jax_cfd/base/time_stepping.py b/jax_cfd/base/time_stepping.py index b9db210..6f8488f 100644 --- a/jax_cfd/base/time_stepping.py +++ b/jax_cfd/base/time_stepping.py @@ -1,29 +1,50 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Time stepping for Navier-Stokes equations.""" - import dataclasses from typing import Callable, Sequence, TypeVar import jax import tree_math +from jax_ib.base import boundaries +from jax_ib.base import grids +from jax_cfd.base import time_stepping +from jax_ib.base import particle_class PyTreeState = TypeVar("PyTreeState") TimeStepFn = Callable[[PyTreeState], PyTreeState] -class ExplicitNavierStokesODE: +class ExplicitNavierStokesODE_Penalty: + """Spatially discretized version of Navier-Stokes. + + The equation is given by: + + ∂u/∂t = explicit_terms(u) + 0 = incompressibility_constraint(u) + """ + + def __init__(self, explicit_terms, pressure_projection,update_BC,Reserve_BC): + self.explicit_terms = explicit_terms + self.pressure_projection = pressure_projection + self.update_BC = update_BC + self.Reserve_BC = Reserve_BC + + + def explicit_terms(self, state): + """Explicitly evaluate the ODE.""" + raise NotImplementedError + + def pressure_projection(self, state): + """Enforce the incompressibility constraint.""" + raise NotImplementedError + + def update_BC(self, state): + """Update Wall BC """ + raise NotImplementedError + + def Reserve_BC(self, state): + """Revert spurious updates of Wall BC """ + raise NotImplementedError + +class ExplicitNavierStokesODE_BCtime: """Spatially discretized version of Navier-Stokes. The equation is given by: @@ -32,9 +53,15 @@ class ExplicitNavierStokesODE: 0 = incompressibility_constraint(u) """ - def __init__(self, explicit_terms, pressure_projection): + def __init__(self, explicit_terms, pressure_projection,update_BC,Reserve_BC,IBM_force,Update_Position,Pressure_Grad,Calculate_Drag): self.explicit_terms = explicit_terms self.pressure_projection = pressure_projection + self.update_BC = update_BC + self.Reserve_BC = Reserve_BC + self.IBM_force = IBM_force + self.Update_Position = Update_Position + self.Pressure_Grad = Pressure_Grad + self.Calculate_Drag = Calculate_Drag def explicit_terms(self, state): """Explicitly evaluate the ODE.""" @@ -44,21 +71,194 @@ def pressure_projection(self, state): """Enforce the incompressibility constraint.""" raise NotImplementedError + def update_BC(self, state): + """Update Wall BC """ + raise NotImplementedError + + def Reserve_BC(self, state): + """Revert spurious updates of Wall BC """ + raise NotImplementedError + def IBM_force(self, state): + """Revert spurious updates of Wall BC """ + raise NotImplementedError + + def Update_Position(self, state): + """Revert spurious updates of Wall BC """ + raise NotImplementedError + + def Pressure_Grad(self, state): + """Revert spurious updates of Wall BC """ + raise NotImplementedError + + def Calculate_Drag(self, state): + """Revert spurious updates of Wall BC """ + raise NotImplementedError @dataclasses.dataclass -class ButcherTableau: +class ButcherTableau_updated: a: Sequence[Sequence[float]] b: Sequence[float] + c: Sequence[float] # TODO(shoyer): add c, when we support time-dependent equations. def __post_init__(self): if len(self.a) + 1 != len(self.b): raise ValueError("inconsistent Butcher tableau") + + +def navier_stokes_rk_updated( + tableau: ButcherTableau_updated, + equation: ExplicitNavierStokesODE_BCtime, + time_step: float, +) -> TimeStepFn: + """Create a forward Runge-Kutta time-stepper for incompressible Navier-Stokes. + This function implements the reference method (equations 16-21), rather than + the fast projection method, from: + "Fast-Projection Methods for the Incompressible Navier–Stokes Equations" + Fluids 2020, 5, 222; doi:10.3390/fluids5040222 -def navier_stokes_rk( - tableau: ButcherTableau, - equation: ExplicitNavierStokesODE, + Args: + tableau: Butcher tableau. + equation: equation to use. + time_step: overall time-step size. + + Returns: + Function that advances one time-step forward. + """ + # pylint: disable=invalid-name + dt = time_step + F = tree_math.unwrap(equation.explicit_terms) + P = tree_math.unwrap(equation.pressure_projection) + M = tree_math.unwrap(equation.update_BC) + R = tree_math.unwrap(equation.Reserve_BC) + IBM = tree_math.unwrap(equation.IBM_force) + Update_Pos = tree_math.unwrap(equation.Update_Position) + Grad_Pressure = tree_math.unwrap(equation.Pressure_Grad) + Drag_Calculation = tree_math.unwrap(equation.Calculate_Drag) + + a = tableau.a + b = tableau.b + num_steps = len(b) + + @tree_math.wrap + def step_fn(u0): + #print('vector',u0) + #new_time = 0#u0[0].bc.time_stamp + dt + u = [None] * num_steps + k = [None] * num_steps + + def convert_to_velocity_vecot(u0): + u = u0.tree + return tree_math.Vector(tuple(u[i].array for i in range(len(u)))) + + def convert_to_velocity_tree(m,bcs): + return tree_math.Vector(tuple(grids.GridVariable(v,bc) for v,bc in zip(m.tree,bcs))) + + def convert_all_variabl_to_velocity_vecot(u0): + u = u0.tree.velocity + #return tree_math.Vector(tuple(grids.GridVariable(v.array,v.bc) for v in u)) + return tree_math.Vector(u) + def covert_veloicty_to_All_variable_vecot(particles,m,pressure,Drag,Step_count,MD_var): + u = m.tree + #return tree_math.Vector(particle_class.All_Variables(particles, tuple(grids.GridVariable(v.array,v.bc) for v in u),pressure)) + return tree_math.Vector(particle_class.All_Variables(particles,u,pressure,Drag,Step_count,MD_var)) + + def velocity_bc(u0): + u = u0.tree.velocity + return tuple(u[i].bc for i in range(len(u))) + + def the_particles(u0): + return u0.tree.particles + def the_pressure(u0): + return u0.tree.pressure + def the_Drag(u0): + return u0.tree.Drag + + + particles = the_particles(u0) + ubc = velocity_bc(u0) + pressure = the_pressure(u0) + Drag = the_Drag(u0) + Step_count = u0.tree.Step_count + MD_var = u0.tree.MD_var + + + u0 = convert_all_variabl_to_velocity_vecot(u0) + + + u[0] = convert_to_velocity_vecot(u0) + k[0] = convert_to_velocity_vecot(F(u0)) + dP = Grad_Pressure(tree_math.Vector(pressure)) + + + + u0 = convert_to_velocity_vecot(u0) + + for i in range(1, num_steps): + #u_star = u0[ww].array + sum(a[i-1][j]*k[j][ww].array for j in range(i) if a[i-1][j]) + + u_star = u0 + dt * sum(a[i-1][j] * k[j] for j in range(i) if a[i-1][j]) + + #u[i] = P(R(u_star)) + u[i] = convert_to_velocity_vecot(P(convert_to_velocity_tree(u_star,ubc))) + k[i] = convert_to_velocity_vecot(F(convert_to_velocity_tree(u[i],ubc))) + + #for ww in range(0,len(u0)): + u_star = u0 + dt * sum(b[j] * k[j] for j in range(num_steps) if b[j])-dP + + Force = IBM(covert_veloicty_to_All_variable_vecot(particles,convert_to_velocity_tree(u_star,ubc),pressure,Drag,Step_count,MD_var)) + + + Drag_variable = Drag_Calculation(covert_veloicty_to_All_variable_vecot(particles,Force,pressure,Drag,Step_count,MD_var)) + Drag = the_Drag(Drag_variable) + + + Force = convert_to_velocity_vecot(Force) + + #Tree_force = convert_to_velocity_tree(Force,ubc) + + + + u_star_star = u_star + dt * Force + + # for i in range(0,2): + # Force = IBM(covert_veloicty_to_All_variable_vecot(particles,convert_to_velocity_tree(u_star_star,ubc),pressure,Drag)) + # if i==1: + # Drag_variable = Drag_Calculation(covert_veloicty_to_All_variable_vecot(particles,Force,pressure,Drag)) + # Drag = the_Drag(Drag_variable) + # Force = convert_to_velocity_vecot(Force) + # u_star_star = u_star+ dt * Force + + #u_final = P(R(u_star)) + #u_final = P(Force) + + + + + u_final = convert_to_velocity_tree(u_star_star,ubc) + + + u_final = covert_veloicty_to_All_variable_vecot(particles,u_final,pressure,Drag,Step_count,MD_var) + + + + u_final = P(u_final) + #u_final = P(u_star_star) + u_final = M(u_final) + + + + + u_final = Update_Pos(u_final) # the time step counter is also updated + + return u_final + + return step_fn + +def navier_stokes_rk_penalty( + tableau: ButcherTableau_updated, + equation: ExplicitNavierStokesODE_BCtime, time_step: float, ) -> TimeStepFn: """Create a forward Runge-Kutta time-stepper for incompressible Navier-Stokes. @@ -80,79 +280,110 @@ def navier_stokes_rk( dt = time_step F = tree_math.unwrap(equation.explicit_terms) P = tree_math.unwrap(equation.pressure_projection) + M = tree_math.unwrap(equation.update_BC) + R = tree_math.unwrap(equation.Reserve_BC) a = tableau.a b = tableau.b num_steps = len(b) - + @tree_math.wrap def step_fn(u0): + #print('vector',u0) + #new_time = 0#u0[0].bc.time_stamp + dt u = [None] * num_steps k = [None] * num_steps - u[0] = u0 - k[0] = F(u0) - + def convert_to_velocity_vecot(u0): + u = u0.tree + return tree_math.Vector(tuple(u[i].array for i in range(len(u)))) + + def convert_to_velocity_tree(m,bcs): + return tree_math.Vector(tuple(grids.GridVariable(v,bc) for v,bc in zip(m.tree,bcs))) + + def convert_all_variabl_to_velocity_vecot(u0): + u = u0.tree.velocity + #return tree_math.Vector(tuple(grids.GridVariable(v.array,v.bc) for v in u)) + return tree_math.Vector(u) + def covert_veloicty_to_All_variable_vecot(particles,m,pressure,Drag,Step_count,MD_var): + u = m.tree + #return tree_math.Vector(particle_class.All_Variables(particles, tuple(grids.GridVariable(v.array,v.bc) for v in u),pressure)) + return tree_math.Vector(particle_class.All_Variables(particles,u,pressure,Drag,Step_count,MD_var)) + + def velocity_bc(u0): + u = u0.tree.velocity + return tuple(u[i].bc for i in range(len(u))) + + def the_particles(u0): + return u0.tree.particles + def the_pressure(u0): + return u0.tree.pressure + def the_Drag(u0): + return u0.tree.Drag + + particles = the_particles(u0) + ubc = velocity_bc(u0) + pressure = the_pressure(u0) + Drag = the_Drag(u0) + Step_count = u0.tree.Step_count + MD_var = u0.tree.MD_var + + + u0 = convert_all_variabl_to_velocity_vecot(u0) + + u[0] = convert_to_velocity_vecot(u0) + k[0] = convert_to_velocity_vecot(F(u0)) + + + + u0 = convert_to_velocity_vecot(u0) + for i in range(1, num_steps): + #u_star = u0[ww].array + sum(a[i-1][j]*k[j][ww].array for j in range(i) if a[i-1][j]) + u_star = u0 + dt * sum(a[i-1][j] * k[j] for j in range(i) if a[i-1][j]) - u[i] = P(u_star) - k[i] = F(u[i]) + + #u[i] = P(R(u_star)) + u[i] = convert_to_velocity_vecot(P(convert_to_velocity_tree(u_star,ubc))) + k[i] = convert_to_velocity_vecot(F(convert_to_velocity_tree(u[i],ubc))) + #for ww in range(0,len(u0)): u_star = u0 + dt * sum(b[j] * k[j] for j in range(num_steps) if b[j]) - u_final = P(u_star) + u_final = convert_to_velocity_tree(u_star,ubc) + + u_final = covert_veloicty_to_All_variable_vecot(particles,u_final,pressure,Drag,Step_count,MD_var) + u_final = P(u_final) + # + u_final = M(u_final) + + + + return u_final return step_fn - -def forward_euler( - equation: ExplicitNavierStokesODE, time_step: float, +def forward_euler_penalty( + equation: ExplicitNavierStokesODE_Penalty, time_step: float, ) -> TimeStepFn: return jax.named_call( - navier_stokes_rk( - ButcherTableau(a=[], b=[1]), + navier_stokes_rk_penalty( + ButcherTableau_updated(a=[], b=[1], c=[0]), equation, time_step), name="forward_euler", ) - -def midpoint_rk2( - equation: ExplicitNavierStokesODE, time_step: float, -) -> TimeStepFn: - return jax.named_call( - navier_stokes_rk( - ButcherTableau(a=[[1/2]], b=[0, 1]), - equation=equation, - time_step=time_step, - ), - name="midpoint_rk2", - ) - - -def heun_rk2( - equation: ExplicitNavierStokesODE, time_step: float, +def forward_euler_updated( + equation: ExplicitNavierStokesODE_BCtime, time_step: float, ) -> TimeStepFn: return jax.named_call( - navier_stokes_rk( - ButcherTableau(a=[[1]], b=[1/2, 1/2]), - equation=equation, - time_step=time_step, - ), - name="heun_rk2", + navier_stokes_rk_updated( + ButcherTableau_updated(a=[], b=[1], c=[0]), + equation, + time_step), + name="forward_euler", ) -def classic_rk4( - equation: ExplicitNavierStokesODE, time_step: float, -) -> TimeStepFn: - return jax.named_call( - navier_stokes_rk( - ButcherTableau(a=[[1/2], [0, 1/2], [0, 0, 1]], - b=[1/6, 1/3, 1/3, 1/6]), - equation=equation, - time_step=time_step, - ), - name="classic_rk4", - )