Skip to content

Commit

Permalink
Some mumbling
Browse files Browse the repository at this point in the history
  • Loading branch information
arpastrana committed Jan 23, 2024
1 parent be30b5e commit 02c1a3b
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 40 deletions.
78 changes: 44 additions & 34 deletions src/jax_fdm/equilibrium/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def nodes_free_positions(self, q, xyz_fixed, loads, structure):
Calculate the XYZ coordinates of the free nodes.
"""
A = self.stiffness_matrix(q, structure)
b = self.force_matrix(q, xyz_fixed, loads, structure)
b = self.load_matrix(q, xyz_fixed, loads, structure)

return self.linearsolve_fn(A, b)

Expand Down Expand Up @@ -157,7 +157,8 @@ def __call__(self, params, structure):
implicit_diff = self.implicit_diff
verbose = self.verbose

xyz = self.equilibrium(q, xyz_fixed, loads_nodes, structure)
xyz_free = self.nodes_free_positions(q, xyz_fixed, loads_nodes, structure)
# xyz = self.equilibrium(q, xyz_fixed, loads_nodes, structure)

if tmax > 1:
# Setting node loads to zero when tmax > 1 if specified
Expand All @@ -167,17 +168,19 @@ def __call__(self, params, structure):
loads_state.edges,
loads_state.faces)

xyz = self.equilibrium_iterative(q,
xyz_fixed,
loads_state,
structure,
xyz_init=xyz,
tmax=tmax,
eta=eta,
solver=solver,
implicit_diff=implicit_diff,
verbose=verbose)
xyz_free = self.equilibrium_iterative(q,
xyz_fixed,
loads_state,
structure,
xyz_free_init=xyz_free,
tmax=tmax,
eta=eta,
solver=solver,
implicit_diff=implicit_diff,
verbose=verbose)

indices = structure.indices_freefixed
xyz = self.nodes_positions(xyz_free, xyz_fixed, indices)
loads_nodes = self.nodes_load(xyz, loads_state, structure)

return self.equilibrium_state(q, xyz, loads_nodes, structure)
Expand All @@ -197,7 +200,7 @@ def equilibrium_iterative(self,
xyz_fixed,
load_state,
structure,
xyz_init=None,
xyz_free_init=None,
tmax=100,
eta=1e-6,
solver=None,
Expand All @@ -211,31 +214,37 @@ def equilibrium_iterative(self,
This function only supports reverse mode auto-differentiation.
To support forward-mode, we should define a custom jvp using implicit differentiation.
"""
def equilibrium_iterative_fn(params, xyz_init):
def load_matrix_fn(xyz, load_state, structure, force_fixed_matrix):
"""
Calculate the effective loads matrix for a fixed-point iteration.
"""
free = structure.indices_free
loads_nodes = self.nodes_load(xyz, load_state, structure)

return loads_nodes[free, :] - force_fixed_matrix

def equilibrium_iterative_fn(params, xyz_free):
"""
This closure function avoids re-computing A and f_fixed throughout iterations
because these two matrices remain constant during the fixed point search.
TODO: Extract closure into function shared with the other nodes equilibrium function?
"""
A, f_fixed, xyz_fixed, load_state = params

free = structure.indices_free
K, F, xyz_fixed, load_state = params
freefixed = structure.indices_freefixed

loads_nodes = self.nodes_load(xyz_init, load_state, structure)
b = loads_nodes[free, :] - f_fixed
xyz_free = self.linearsolve_fn(A, b)
xyz_ = self.nodes_positions(xyz_free, xyz_fixed, freefixed)
xyz = self.nodes_positions(xyz_free, xyz_fixed, freefixed)
b = load_matrix_fn(xyz, load_state, structure, F)

return xyz_
return self.linearsolve_fn(K, b)

# recompute xyz_init if not input
if xyz_init is None:
xyz_init = self.equilibrium(q, xyz_fixed, load_state.nodes, structure)
if xyz_free_init is None:
# xyz_free_init = self.equilibrium(q, xyz_fixed, load_state.nodes, structure)
xyz_free_init = self.nodes_free_positions(q, xyz_fixed, load_state.nodes, structure)

A = self.stiffness_matrix(q, structure)
f_fixed = self.force_fixed_matrix(q, xyz_fixed, structure)
K = self.stiffness_matrix(q, structure)
F = self.force_fixed_matrix(q, xyz_fixed, structure)

solver = solver or self.itersolve_fn
solver_config = {"tmax": tmax,
Expand All @@ -246,15 +255,16 @@ def equilibrium_iterative_fn(params, xyz_init):

solver_kwargs = {"solver_config": solver_config,
"f": equilibrium_iterative_fn,
"a": (A, f_fixed, xyz_fixed, load_state),
"x_init": xyz_init}
"a": (K, F, xyz_fixed, load_state),
"x_init": xyz_free_init}

if implicit_diff:
xyz_new = fixed_point(solver, **solver_kwargs)

xyz_new = solver(**solver_kwargs)
# xyz_new = fixed_point(solver, **solver_kwargs)
return fixed_point(solver, **solver_kwargs)

return xyz_new
# xyz_new = solver(**solver_kwargs)
# return xyz_new
return solver(**solver_kwargs)

# ----------------------------------------------------------------------
# Equilibrium state
Expand Down Expand Up @@ -291,9 +301,9 @@ def stiffness_matrix(q, structure):

return c_free.T @ (q[:, None] * c_free)

def force_matrix(self, q, xyz_fixed, loads, structure):
def load_matrix(self, q, xyz_fixed, loads, structure):
"""
The force residual matrix of the structure.
The load matrix of the structure.
"""
free = structure.indices_free

Expand Down
46 changes: 40 additions & 6 deletions src/jax_fdm/equilibrium/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,19 +152,53 @@ def fixed_point_bwd(solver, solver_config, fn, res, x_star_bar):
The backward pass of a fixed point solver.
"""
a, x_star = res

# equilibrium constraint
def residual_fn(params):
x_star, A, b = params
return A @ x_star - b

A = a[0]
params = (x_star, A, b)

_, vjp_params = vjp(residual_fn, params)


_, vjp_a = vjp(lambda a: fn(a, x_star), a)

def adjoint_iterative(packed, u):
a, x_star, x_star_bar = packed
_, vjp_x = vjp(lambda x: fn(a, x), x_star)
return x_star_bar + vjp_x(u)[0]

w = solver(adjoint_iterative,
(a, x_star, x_star_bar),
x_star_bar,
solver_config)

a_bar = vjp_a(w)[0]

return a_bar, None


def fixed_point_bwd_backup(solver, solver_config, fn, res, x_star_bar):
"""
The backward pass of a fixed point solver.
"""
a, x_star = res
_, vjp_a = vjp(lambda a: fn(a, x_star), a)

def rev_iter(packed, u):
def adjoint_iterative(packed, u):
a, x_star, x_star_bar = packed
_, vjp_x = vjp(lambda x: fn(a, x), x_star)
return x_star_bar + vjp_x(u)[0]

partial_func = solver(rev_iter,
(a, x_star, x_star_bar),
x_star_bar,
solver_config)
lam = solver(adjoint_iterative,
(a, x_star, x_star_bar),
x_star_bar,
solver_config)

a_bar = vjp_a(partial_func)[0]
a_bar = vjp_a(lam)[0]

return a_bar, None

Expand Down

0 comments on commit 02c1a3b

Please sign in to comment.