diff --git a/src/enzax/steady_state.py b/src/enzax/steady_state.py index 8aa6b18..5f7e3bf 100644 --- a/src/enzax/steady_state.py +++ b/src/enzax/steady_state.py @@ -28,28 +28,21 @@ def dC_dt_sqrd( @eqx.filter_jit() def lagrangian( - z: Float[Array, " n_balanced*2"], + x: Float[Array, " n_balanced"], model: RateEquationModel, -) -> Float[Array, " n_balanced*2"]: - n_balanced = len(model.structure.balanced_species) - F = jnp.ones((2*n_balanced,1)) - x = jnp.exp(z[0:n_balanced]) +) -> Float[Array, " n_balanced"]: + x = x conc = jnp.zeros(model.structure.S.shape[0]) conc = conc.at[model.structure.balanced_species].set(x) conc = conc.at[model.structure.unbalanced_species].set( jnp.exp(model.parameters.log_conc_unbalanced) ) - lamb = z[n_balanced:] - ddc_dt_sqrd_dc = jax.grad(dC_dt_sqrd, argnums=1)(model, x, conc) - ddc_dt_dc = jax.jacfwd(model.dcdt, argnums=1)(0, x) - F = F.at[0:n_balanced, 0].set(ddc_dt_sqrd_dc - jnp.multiply(lamb,ddc_dt_dc).sum(axis=0)) - F = F.at[n_balanced:, 0].set(model.dcdt(0, x)) - return F.T[0] + jac = jax.jacfwd(model.dcdt, argnums=1)(0, x) + return -jnp.linalg.inv(jac)@model.dcdt(0, x) @eqx.filter_jit() def get_steady_state_lagrangian( guess: Float[Array, " n_balanced"], - lambda_guess: Float[Array, " n_balanced"], model: RateEquationModel, ) -> Float[Array, " n_balanced"]: """Get the steady state of a kinetic model, using optimistix. @@ -72,16 +65,15 @@ def get_steady_state_lagrangian( :param model: a KineticModel object """ n_balanced = len(model.structure.balanced_species) - solver = optx.Dogleg(rtol=1e-2, atol=1e-5) + solver = optx.Newton(rtol=1e-8, atol=1e-10) sol = optx.root_find( lagrangian, solver, - jnp.concat([jnp.log(guess), lambda_guess]), + guess, args=model, max_steps=int(1e5), ) - opt_conc = jnp.exp(sol.value[0:n_balanced]) - return opt_conc + return sol.value @eqx.filter_jit() def get_kinetic_model_steady_state(