diff --git a/jpc/_train.py b/jpc/_train.py
index 75e8c47..204bc93 100644
--- a/jpc/_train.py
+++ b/jpc/_train.py
@@ -21,7 +21,7 @@
     pc_energy_fn,
     compute_infer_energies,
     compute_pc_param_grads,
-    compute_grad_norms,
+    compute_param_norms,
     compute_accuracy,
     hpc_energy_fn,
     compute_hpc_param_grads
@@ -54,6 +54,7 @@ def make_pc_step(
        record_energies: bool = False,
        record_every: int = None,
        activity_norms: bool = False,
+       param_norms: bool = False,
        grad_norms: bool = False,
        calculate_accuracy: bool = False
 ) -> Dict:
@@ -75,8 +76,8 @@ def make_pc_step(
 
     **Other arguments:**
 
-    - `loss`: Loss function to use at the output layer (mean squared error
-      'MSE' vs cross-entropy 'CE').
+    - `loss_id`: Loss function for the output layer (mean squared error 'MSE'
+      vs cross-entropy 'CE').
     - `ode_solver`: Diffrax ODE solver to be used. Default is Heun, a 2nd
       order explicit Runge--Kutta method.
     - `max_t1`: Maximum end of integration region (20 by default).
@@ -98,15 +99,16 @@ def make_pc_step(
       inference iteration.
     - `record_every`: int determining the sampling frequency the integration
       steps.
-    - `activity_norms`: If `True`, computes norm of the activities.
-    - `grad_norms`: If `True`, computes norm of parameter gradients.
+    - `activity_norms`: If `True`, computes l2 norm of the activities.
+    - `param_norms`: If `True`, computes l2 norm of the parameters.
+    - `grad_norms`: If `True`, computes l2 norm of parameter gradients.
     - `calculate_accuracy`: If `True`, computes the training accuracy.
 
     **Returns:**
 
-    Dict including model with updated parameters, optimiser, updated optimiser
-    state, loss, energies, equilibrated activities, and optionally other
-    metrics.
+    Dict including model (and optional skip model) with updated parameters,
+    optimiser, updated optimiser state, loss, energies, activities,
+    and optionally other metrics (see other args above).
 
     **Raises:**
 
@@ -147,7 +149,7 @@ def make_pc_step(
         activities=activities,
         y=output,
         x=input,
-        loss=loss_id,
+        loss_id=loss_id,
         solver=ode_solver,
         max_t1=max_t1,
         dt=dt,
@@ -155,9 +157,12 @@ def make_pc_step(
         record_iters=record_activities,
         record_every=record_every
     )
-    activity_norms = (compute_activity_norms(equilib_activities)
-                      if activity_norms else None)
-    t_max = get_t_max(equilib_activities) if record_activities else None
+    t_max = get_t_max(equilib_activities) if record_activities else 0
+    activity_norms = (compute_activity_norms(
+        activities=tree_map(
+            lambda act: act[t_max], equilib_activities
+        )
+    ) if activity_norms else None)
     energies = compute_infer_energies(
         params=(model, skip_model),
         activities_iters=equilib_activities,
@@ -177,6 +182,9 @@ def make_pc_step(
             record_layers=True
         )
 
+    param_norms = compute_param_norms(
+        (model, skip_model)
+    ) if param_norms else (None, None)
     param_grads = compute_pc_param_grads(
         params=(model, skip_model),
         activities=tree_map(
@@ -185,9 +193,9 @@ def make_pc_step(
             lambda act: act[t_max] if record_activities else act,
            equilib_activities
         ),
         y=output,
         x=input,
-        loss=loss_id
+        loss_id=loss_id
     )
-    grad_norms = compute_grad_norms(param_grads) if grad_norms else (None, None)
+    grad_norms = compute_param_norms(param_grads) if grad_norms else (None, None)
     updates, opt_state = optim.update(
         updates=param_grads,
         state=opt_state,
@@ -217,6 +225,8 @@ def make_pc_step(
         "t_max": t_max,
         "energies": energies,
         "activity_norms": activity_norms,
+        "model_param_norms": param_norms[0],
+        "skip_model_param_norms": param_norms[1],
         "model_grad_norms": grad_norms[0],
         "skip_model_grad_norms": grad_norms[1]
     }
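
Note on the `compute_grad_norms` → `compute_param_norms` rename: model parameters and their gradients are both pytrees, so a single L2-norm helper can back the `param_norms` and `grad_norms` flags alike, which is why the diff calls the same function at both sites. A minimal sketch of such a helper, assuming an Equinox-style model pytree (the actual jpc implementation is not shown in this diff):

```python
import equinox as eqx
import jax.numpy as jnp
from jax import tree_util


def l2_norms_per_leaf(tree):
    """Replace every array leaf of `tree` with its L2 (Frobenius) norm."""
    arrays = eqx.filter(tree, eqx.is_array)  # drop non-array leaves (e.g. activation fns)
    return tree_util.tree_map(jnp.linalg.norm, arrays)


# The same helper would serve both call sites in the diff:
# l2_norms_per_leaf((model, skip_model))  ->  per-layer parameter norms
# l2_norms_per_leaf(param_grads)          ->  per-layer gradient norms
```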
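On reordering `t_max` before the activity-norm computation: with `record_activities=True`, each activity leaf carries a leading time axis of recorded solver steps, so the norms should be taken at the equilibrated (last recorded) step rather than over the whole trajectory. A self-contained illustration of the `tree_map` indexing pattern the diff adopts (array shapes are made up for the example):

```python
import jax.numpy as jnp
from jax import tree_util

# Two "layers" of recorded activities: 5 solver steps, batch of 2.
recorded = [jnp.ones((5, 2, 3)), jnp.ones((5, 2, 4))]
t_max = 4  # index of the last recorded step

# Select the equilibrated activities before computing any norms.
final = tree_util.tree_map(lambda act: act[t_max], recorded)
print([leaf.shape for leaf in final])  # [(2, 3), (2, 4)]
```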
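For context, a rough sketch of how the new diagnostics would be consumed. Only `loss_id`, `param_norms`, `grad_norms`, and the returned dictionary keys are taken from this diff; the remaining argument names are assumptions for illustration:

```python
import jpc

result = jpc.make_pc_step(
    model=model,
    optim=optim,
    opt_state=opt_state,
    output=y_batch,
    input=x_batch,
    loss_id="MSE",     # keyword renamed from `loss` in this diff
    param_norms=True,  # new: record L2 norms of the parameters
    grad_norms=True,   # record L2 norms of the parameter gradients
)
print(result["model_param_norms"])       # norms for the main model
print(result["skip_model_param_norms"])  # norms for the optional skip model
```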