Skip to content

Commit

Permalink
Update main step function.
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Nov 19, 2024
1 parent d565321 commit 1244459
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions jpc/_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
pc_energy_fn,
compute_infer_energies,
compute_pc_param_grads,
compute_grad_norms,
compute_param_norms,
compute_accuracy,
hpc_energy_fn,
compute_hpc_param_grads
Expand Down Expand Up @@ -54,6 +54,7 @@ def make_pc_step(
record_energies: bool = False,
record_every: int = None,
activity_norms: bool = False,
param_norms: bool = False,
grad_norms: bool = False,
calculate_accuracy: bool = False
) -> Dict:
Expand All @@ -75,8 +76,8 @@ def make_pc_step(
**Other arguments:**
- `loss`: Loss function to use at the output layer (mean squared error
'MSE' vs cross-entropy 'CE').
- `loss_id`: Loss function for the output layer (mean squared error 'MSE'
vs cross-entropy 'CE').
- `ode_solver`: Diffrax ODE solver to be used. Default is Heun, a 2nd order
explicit Runge--Kutta method.
- `max_t1`: Maximum end of integration region (20 by default).
Expand All @@ -98,15 +99,16 @@ def make_pc_step(
inference iteration.
- `record_every`: int determining the sampling frequency the integration
steps.
- `activity_norms`: If `True`, computes norm of the activities.
- `grad_norms`: If `True`, computes norm of parameter gradients.
- `activity_norms`: If `True`, computes l2 norm of the activities.
- `param_norms`: If `True`, computes l2 norm of the parameters.
- `grad_norms`: If `True`, computes l2 norm of parameter gradients.
- `calculate_accuracy`: If `True`, computes the training accuracy.
**Returns:**
Dict including model with updated parameters, optimiser, updated optimiser
state, loss, energies, equilibrated activities, and optionally other
metrics.
Dict including model (and optional skip model) with updated parameters,
optimiser, updated optimiser state, loss, energies, activities,
and optionally other metrics (see other args above).
**Raises:**
Expand Down Expand Up @@ -147,17 +149,20 @@ def make_pc_step(
activities=activities,
y=output,
x=input,
loss=loss_id,
loss_id=loss_id,
solver=ode_solver,
max_t1=max_t1,
dt=dt,
stepsize_controller=stepsize_controller,
record_iters=record_activities,
record_every=record_every
)
activity_norms = (compute_activity_norms(equilib_activities)
if activity_norms else None)
t_max = get_t_max(equilib_activities) if record_activities else None
t_max = get_t_max(equilib_activities) if record_activities else 0
activity_norms = (compute_activity_norms(
activities=tree_map(
lambda act: act[t_max], equilib_activities
)
) if activity_norms else None)
energies = compute_infer_energies(
params=(model, skip_model),
activities_iters=equilib_activities,
Expand All @@ -177,6 +182,9 @@ def make_pc_step(
record_layers=True
)

param_norms = compute_param_norms(
(model, skip_model)
) if param_norms else (None, None)
param_grads = compute_pc_param_grads(
params=(model, skip_model),
activities=tree_map(
Expand All @@ -185,9 +193,9 @@ def make_pc_step(
),
y=output,
x=input,
loss=loss_id
loss_id=loss_id
)
grad_norms = compute_grad_norms(param_grads) if grad_norms else (None, None)
grad_norms = compute_param_norms(param_grads) if grad_norms else (None, None)
updates, opt_state = optim.update(
updates=param_grads,
state=opt_state,
Expand Down Expand Up @@ -217,6 +225,8 @@ def make_pc_step(
"t_max": t_max,
"energies": energies,
"activity_norms": activity_norms,
"model_param_norms": param_norms[0],
"skip_model_param_norms": param_norms[1],
"model_grad_norms": grad_norms[0],
"skip_model_grad_norms": grad_norms[1]
}
Expand Down

0 comments on commit 1244459

Please sign in to comment.