diff --git a/README.md b/README.md
index 82f3482..fc0aff2 100644
--- a/README.md
+++ b/README.md
@@ -59,6 +59,9 @@ pip install --upgrade "jax[cuda12]"
 ## 📖 [Documentation](https://thebuckleylab.github.io/jpc/)
 Available at https://thebuckleylab.github.io/jpc/.
 
+## 🧠 PC primer
+
+
 ## ⚡️ Quick example
 Use `jpc.make_pc_step` to update the parameters of any neural network compatible
 with PC updates (see the [notebook examples
diff --git a/docs/_static/pc_primer.png b/docs/_static/pc_primer.png
index f6dfa26..bf73046 100644
Binary files a/docs/_static/pc_primer.png and b/docs/_static/pc_primer.png differ
diff --git a/docs/_static/pc_primer.svg b/docs/_static/pc_primer.svg
deleted file mode 100644
index 7f21c71..0000000
--- a/docs/_static/pc_primer.svg
+++ /dev/null
@@ -1,299 +0,0 @@
-mathematics, and the core library is <1000 lines of code. Unlike existing implementations,
-JPC leverages ordinary differential equation solvers (ODE) to integrate the gradient flow
-inference dynamics of PCNs. JPC also provides some theoretical tools that can be used to
-study and potentially identify problems with PCNs.
-
-The rest of the paper is structured as follows. After a brief review of PC (§2), we showcase
-some empirical results showing that a second-order ODE solver can achieve significantly
-faster runtimes than standard Euler integration of the gradient flow PC inference dynamics,
-with comparable performance on different datasets and networks (§3). We then explain the
-library’s core implementation (§4), before concluding with possible extensions (§5).
-
-2 Predictive coding: A primer
-
-Here we include a minimal presentation of PC necessary to get started with JPC. The reader
-is referred to [14, 8, 7, 12] for reviews and to [1] for a more formal treatment.
-PCNs are typically defined by an energy function which is a sum of squared prediction errors
-across layers, and which for a standard feedforward network takes the form
-
-$$\mathcal{F} = \sum_{\ell=1}^{L} \| z_\ell - f_\ell(W_\ell z_{\ell-1}) \|^2 \qquad (1)$$
-
-where $z_\ell$ is the activity of a given layer and $f_\ell$ is some activation function.
-We ignore multiple data points and biases for simplicity.
-
-To train a PCN, the last layer is clamped to some data, $z_L := y$. This could be a label for
-classification or an image for generation, and these two settings are typically referred to
-as discriminative and generative PC. The first layer can also be fixed to some data serving
-as a “prior”, $z_0 := x$, such as an image in a supervised task. In unsupervised training,
-this layer is left free to vary like any other hidden layer.
-
-The energy (Eq. 1) is then minimised in a bi-level fashion, first w.r.t. the activities
-(inference) and then w.r.t. the weights (learning)
-
-[Diagram: network drawn between $x$ and $y$]
-
-Infer: $\arg\min_{z_\ell} \mathcal{F}$ (2)
-
-Learn: $\arg\min_{W_\ell} \mathcal{F}$ (3)
-
-The inference dynamics are generally first run to convergence until $\Delta z_\ell \approx 0$.
-Then, at the reached equilibrium of the activities, the weights are updated via common neural
-network optimisers such as stochastic GD or Adam (Eq. 3). This process is repeated for every
-training step, typically for a given data batch. Inference is typically performed by standard
-GD on the energy, which can be seen as the Euler discretisation of the gradient system
-$\dot{z}_\ell = -\partial \mathcal{F} / \partial z_\ell$.
-JPC simply leverages well-tested ODE solvers to integrate this gradient flow.
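The inference phase described in the primer (clamp the first and last layers, then run gradient descent on the energy over the hidden activities) can be sketched on a toy problem. This is a minimal illustration in plain Python, not JPC's API: it assumes a chain of scalar layers with `tanh` activations, and the names `energy`, `energy_grad_z` and `infer_euler`, the step size and the example weights are all hypothetical choices made for this sketch.

```python
import math


def energy(z, w):
    """Energy of Eq. 1 for a scalar chain: sum_l (z[l] - tanh(w[l-1] * z[l-1]))**2."""
    return sum(
        (z[l] - math.tanh(w[l - 1] * z[l - 1])) ** 2 for l in range(1, len(z))
    )


def energy_grad_z(z, w, l):
    """Derivative of the energy w.r.t. the hidden activity z[l] (0 < l < L)."""
    # Error of layer l against the prediction coming from layer l-1.
    err_l = z[l] - math.tanh(w[l - 1] * z[l - 1])
    # Error of layer l+1, whose prediction depends on z[l].
    pre_next = w[l] * z[l]
    err_next = z[l + 1] - math.tanh(pre_next)
    dtanh = 1.0 - math.tanh(pre_next) ** 2
    return 2.0 * err_l - 2.0 * err_next * dtanh * w[l]


def infer_euler(z, w, dt=0.1, n_steps=200):
    """Euler integration of dz/dt = -dF/dz over hidden layers; z[0], z[-1] stay clamped."""
    z = list(z)
    hidden = range(1, len(z) - 1)
    for _ in range(n_steps):
        # Evaluate all gradients first so every layer is updated from the same state.
        grads = [energy_grad_z(z, w, l) for l in hidden]
        for l, g in zip(hidden, grads):
            z[l] -= dt * g
    return z


# Assumed toy setup: input x, target y, three weights (so L = 3).
x, y = 0.5, 0.8
w = [0.7, -1.2, 0.9]

# Initialise hidden activities with a forward pass, then clamp the last layer to y.
z_init = [x]
for w_l in w:
    z_init.append(math.tanh(w_l * z_init[-1]))
z_init[-1] = y

z_eq = infer_euler(z_init, w)
print("energy before inference:", energy(z_init, w))
print("energy after inference: ", energy(z_eq, w))
```

In this sketch the weight update of Eq. 3 would follow at `z_eq`; swapping the loop body of `infer_euler` for a higher-order scheme changes only how the same gradient flow is integrated.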
-3 Runtime efficiency
-
-A comprehensive benchmarking of various types of PCN with GD as inference optimiser was
-recently performed by [10]. For this reason, here we focus on runtime efficiency, comparing
-standard Euler integration of the inference gradient flow dynamics with Heun, a second-order
-explicit Runge–Kutta method. Note that, as a second-order method, Heun has a higher
-computational cost than Euler; however, it could still be faster if it requires significantly
-fewer steps to converge.
-
-The solvers were compared on feedforward networks trained to classify standard image datasets,
-with different number of hidden layers $H \in \{3, 5, 10\}$. Because our goal was to specifically
diff --git a/docs/index.md b/docs/index.md
index 94d97b3..3da3cd6 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -40,6 +40,9 @@ here).
 pip install --upgrade "jax[cuda12]"
 ```
 
+## 🧠 PC primer
+
+
 ## ⚡️ Quick example
 Use `jpc.make_pc_step` to update the parameters of any neural network compatible with
 PC updates (see [examples](https://thebuckleylab.github.io/jpc/examples/discriminative_pc/))