Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document basis #274

Merged
merged 32 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
a6d60d7
add svgs basis
BalzaniEdoardo Dec 2, 2024
4be4c63
added table for basis
BalzaniEdoardo Dec 2, 2024
98eaf25
fix notebook
BalzaniEdoardo Dec 2, 2024
7a13bcd
fix notebook
BalzaniEdoardo Dec 2, 2024
0a53a59
ADDED ORTH exp
BalzaniEdoardo Dec 2, 2024
4777a25
fixed entry
BalzaniEdoardo Dec 2, 2024
bd88be2
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 3, 2024
ba94341
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 3, 2024
927d021
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 3, 2024
c07aced
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 3, 2024
be197c1
add script to gen figs
BalzaniEdoardo Dec 4, 2024
0297a36
merged basis refactor
BalzaniEdoardo Dec 4, 2024
e39b1c0
Fix table layout
BalzaniEdoardo Dec 4, 2024
c49ffc1
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 4, 2024
aafa384
fix background note path
BalzaniEdoardo Dec 4, 2024
0fa061f
fix background note path
BalzaniEdoardo Dec 4, 2024
4d2bcb3
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 4, 2024
44e61cc
fixing links
BalzaniEdoardo Dec 4, 2024
442f506
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 4, 2024
09e0c8e
use plot directive for thumbnail
BalzaniEdoardo Dec 4, 2024
57e2bd2
fixed description of basis
BalzaniEdoardo Dec 5, 2024
bcf72ce
merged basis PR1
BalzaniEdoardo Dec 10, 2024
2eb55a9
Merge branch 'development' into document_basis
BalzaniEdoardo Dec 11, 2024
ddc357b
tweak text in plot 1d
billbrod Dec 11, 2024
e384d10
try to add summary to plot 2d
billbrod Dec 11, 2024
a1e6ecd
correct filename typo
billbrod Dec 11, 2024
acda0f2
small fixes in basis/readme
billbrod Dec 11, 2024
387aeff
Update docs/background/basis/README.md
BalzaniEdoardo Dec 12, 2024
0fa80d4
fix paths and module name
BalzaniEdoardo Dec 12, 2024
77017d4
Merge branch 'document_basis' of github.com:flatironinstitute/nemos i…
BalzaniEdoardo Dec 12, 2024
01e6e7d
linted
BalzaniEdoardo Dec 12, 2024
afdf8a8
added pages in readme subsection
BalzaniEdoardo Dec 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions docs/assets/stylesheets/custom.css
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,29 @@ html[data-theme=light]{
font-weight: normal;
}

/*!* Style the brackets *!*/
/*span.fn-bracket {*/
/* color: #666; !* Dim the brackets *!*/
/*}*/

/*!* Style the links within the footnotes *!*/
/*aside.footnote a {*/
/* color: #007BFF; !* Blue link color *!*/
/* text-decoration: none; !* Remove underline *!*/
/*}*/

/*aside.footnote a:hover {*/
/* text-decoration: underline; !* Add underline on hover *!*/
/*}*/
table.table-center{
text-align: center;
}

#table-basis {
table-layout: auto;
width: 100%;
}

#table-basis th:nth-child(1), #table-basis td:nth-child(1) {
width: 22%;
}

#table-basis th:nth-child(2), #table-basis td:nth-child(2) {
width: 22%;
}
#table-basis th:nth-child(3), #table-basis td:nth-child(3) {
width: 10%;
}

#table-basis th:nth-child(4), #table-basis td:nth-child(4) {
width: 20%;
}
#table-basis th:nth-child(5), #table-basis td:nth-child(5) {
width: 10%;
}
20 changes: 5 additions & 15 deletions docs/background/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,17 @@ plot_00_conceptual_intro.md

:::{grid-item-card}

<figure>
<img src="../_static/thumbnails/background/plot_01_1D_basis_function.svg" style="height: 100px", alt="One-Dimensional Basis."/>
</figure>

```{toctree}
:maxdepth: 2
```{eval-rst}

plot_01_1D_basis_function.md
.. plot:: scripts/basis_table_figs.py plot_raised_cosine_linear
:show-source-link: False
:height: 100px
```
:::

:::{grid-item-card}

<figure>
<img src="../_static/thumbnails/background/plot_02_ND_basis_function.svg" style="height: 100px", alt="N-Dimensional Basis."/>
</figure>

```{toctree}
:maxdepth: 2

plot_02_ND_basis_function.md
basis/README.md
```
:::

Expand Down
128 changes: 128 additions & 0 deletions docs/background/basis/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Basis Function

(table_basis)=
```{eval-rst}

.. role:: raw-html(raw)
:format: html

.. list-table::
:header-rows: 1
:name: table-basis
:align: center

* - **Basis**
- **Kernel Visualization**
- **Examples**
- **Evaluation/Convolution**
- **Preferred Mode**
* - **B-Spline**
- .. plot:: scripts/basis_table_figs.py plot_bspline
:show-source-link: False
:height: 80px
- :ref:`Grid cells <grid_cells_nemos>`
- :class:`~nemos.basis.BSplineEval` :raw-html:`<br />`
:class:`~nemos.basis.BSplineConv`
- 🟢 Eval
* - **Cyclic B-Spline**
- .. plot:: scripts/basis_table_figs.py plot_cyclic_bspline
:show-source-link: False
:height: 80px
- :ref:`Place cells <basis_eval_place_cells>`
- :class:`~nemos.basis.CyclicBSplineEval` :raw-html:`<br />`
:class:`~nemos.basis.CyclicBSplineConv`
- 🟢 Eval
* - **M-Spline**
- .. plot:: scripts/basis_table_figs.py plot_mspline
:show-source-link: False
:height: 80px
- :ref:`Place cells <basis_eval_place_cells>`
- :class:`~nemos.basis.MSplineEval` :raw-html:`<br />`
:class:`~nemos.basis.MSplineConv`
- 🟢 Eval
* - **Linearly Spaced Raised Cosine**
- .. plot:: scripts/basis_table_figs.py plot_raised_cosine_linear
:show-source-link: False
:height: 80px
-
- :class:`~nemos.basis.RaisedCosineLinearEval` :raw-html:`<br />`
:class:`~nemos.basis.RaisedCosineLinearConv`
- 🟢 Eval
* - **Log Spaced Raised Cosine**
- .. plot:: scripts/basis_table_figs.py plot_raised_cosine_log
:show-source-link: False
:height: 80px
- :ref:`Head Direction <head_direction_reducing_dimensionality>`
- :class:`~nemos.basis.RaisedCosineLogEval` :raw-html:`<br />`
:class:`nemos.basis.RaisedCosineLogConv`
- 🔵 Conv
* - **Orthogonalized Exponential Decays**
- .. plot:: scripts/basis_table_figs.py plot_orth_exp_basis
:show-source-link: False
:height: 80px
-
- :class:`~nemos.basis.OrthExponentialEval` :raw-html:`<br />`
:class:`~nemos.basis.OrthExponentialConv`
- 🟢 Eval
```

## Overview

A basis function is a collection of simple building blocks—functions that, when combined (weighted and summed together), can represent more complex, non-linear relationships. Think of them as tools for constructing predictors in GLMs, helping to model:

1. **Non-linear mappings** between task variables (like velocity or position) and firing rates.
2. **Linear temporal effects**, such as spike history, neuron-to-neuron couplings, or how stimuli are integrated over time.

In a GLM, we assume a non-linear mapping exists between task variables and neuronal firing rates. This mapping isn’t something we can directly observe—what we do see are the inputs (task covariates) and the resulting neural activity. The challenge is to infer a "good" approximation of this hidden relationship.

Basis functions help simplify this process by representing the non-linearity as a weighted sum of fixed functions, $\psi_1(x), \dots, \psi_n(x)$, with weights $\alpha_1, \dots, \alpha_n$. Mathematically:

$$
f(x) \approx \alpha_1 \psi_1(x) + \dots + \alpha_n \psi_n(x)
$$

Here, $\approx$ means "approximately equal".

Instead of tackling the hard problem of learning an unknown function $f(x)$ directly, we reduce it to the simpler task of learning the weights $\{\alpha_i\}$.
BalzaniEdoardo marked this conversation as resolved.
Show resolved Hide resolved


## Basis in NeMoS

NeMoS provides a variety of basis functions (see the [table](table_basis) above). For each basis type, there are two dedicated classes of objects, corresponding to the two key uses described in the overview:

- **Eval-basis objects**: For representing non-linear mappings between task variables and outputs. These objects are identified by names starting with `Eval`.
- **Conv-basis objects**: For linear temporal effects. These objects are identified by names starting with `Conv`.

`Eval` and `Conv` objects can be combined to construct multi-dimensional basis functions, enabling complex feature construction.

## Learn More

::::{grid} 1 2 2 2

:::{grid-item-card}

<figure>
<img src="../../_static/thumbnails/background/plot_01_1D_basis_function.svg" style="height: 100px", alt="One-Dimensional Basis."/>
</figure>

```{toctree}
:maxdepth: 2

plot_01_1D_basis_function.md
```
:::

:::{grid-item-card}

<figure>
<img src="../../_static/thumbnails/background/plot_02_ND_basis_function.svg" style="height: 100px", alt="N-Dimensional Basis."/>
</figure>

```{toctree}
:maxdepth: 2

plot_02_ND_basis_function.md
```
:::

::::
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ warnings.filterwarnings(
## Defining a 1D Basis Object

We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.MSplineEval).
The hyperparameters required to initialize this class are:
The hyperparameters needed to initialize this class are:

- The number of basis functions, which should be a positive integer.
- The order of the spline, which should be an integer greater than 1.
- The number of basis functions, which should be a positive integer (required).
- The order of the spline, which should be an integer greater than 1 (optional, default 4 for a cubic spline).

```{code-cell} ipython3
import matplotlib.pylab as plt
Expand Down Expand Up @@ -93,22 +93,27 @@ if root:
path = Path(root) / "html/_static/thumbnails/background"
# if local store in ../_build/html/...
else:
path = Path("../_build/html/_static/thumbnails/background")
path = Path("../../_build/html/_static/thumbnails/background")

# make sure the folder exists if run from build
if root or Path("../_build/html/_static").exists():
if root or Path("../../_build/html/_static").exists():
path.mkdir(parents=True, exist_ok=True)

if path.exists():
fig.savefig(path / "plot_01_1D_basis_function.svg")


print(path.resolve(), path.exists())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this section still necessary (with the plot_directive)? and, regardless, we probably don't need to print out the path right, that's just for internal debugging

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed it!

```

## Feature Computation
The bases in the `nemos.basis` module can be grouped into two categories:
All bases in the `nemos.basis` module perform a transformation of one or more time series into a set of features. This operation is always carried out by the method [`compute_features`](nemos.basis._basis.Basis.compute_features).
We can be group the bases into two categories depending on the type of transformation that [`compute_features`](nemos.basis._basis.Basis.compute_features) applies:

1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ends with "Eval," such as `BSplineEval`.

1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ending with "Eval," such as `BSplineEval`.
2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv", such as `BSplineConv`.

2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv," such as `BSplineConv`.

Let's see how this two modalities operate.

Expand Down Expand Up @@ -153,7 +158,7 @@ Convolution is performed in "valid" mode, and then NaN-padded. The default behav
is padding left, which makes the output feature causal.
This is why the first half of the `conv_feature` is full of NaNs and appears as white.
If you want to learn more about convolutions, as well as how and when to change defaults
check out the tutorial on [1D convolutions](plot_03_1D_convolution).
check out the tutorial on [1D convolutions](convolution_background).
:::


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ if root:
path = Path(root) / "html/_static/thumbnails/background"
# if local store in ../_build/html/...
else:
path = Path("../_build/html/_static/thumbnails/background")
path = Path("../../_build/html/_static/thumbnails/background")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same point about whether we need this anymore

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed this too


# make sure the folder exists if run from build
if root or Path("../_build/html/_static").exists():
Expand Down
11 changes: 10 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
'sphinx.ext.mathjax',
'sphinx_autodoc_typehints',
'sphinx_togglebutton',
'matplotlib.sphinxext.plot_directive',
"matplotlib.sphinxext.mathmpl",
]

myst_enable_extensions = [
Expand Down Expand Up @@ -121,7 +123,11 @@
"logo": {
"image_light": "_static/NeMoS_Logo_CMYK_Full.svg",
"image_dark": "_static/NeMoS_Logo_CMYK_White.svg",
}
},
"secondary_sidebar_items": {
"**": ["page-toc", "sourcelink"],
"background/basis/README": [],
},
}

html_sidebars = {
Expand Down Expand Up @@ -160,3 +166,6 @@
nb_execution_excludepatterns = ["tutorials/**", "how_to_guide/**", "background/**"]

viewcode_follow_imported_members = True

# option for mpl extension
plot_html_show_formats = False
56 changes: 56 additions & 0 deletions docs/scripts/basis_table_figs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import matplotlib.pyplot as plt
import numpy as np

import nemos as nmo
from nemos._inspect_utils import trim_kwargs

plt.rcParams.update(
{
"figure.dpi": 300,
}
)

KWARGS = dict(
n_basis_funcs=10,
decay_rates=np.arange(1, 10 + 1),
enforce_decay_to_zero=True,
order=4,
width=2,
)


def plot_basis(cls):
cls_params = cls._get_param_names()
new_kwargs = trim_kwargs(cls, KWARGS.copy(), {cls.__name__: cls_params})
bas = cls(**new_kwargs)
fig, ax = plt.subplots(1, 1, figsize=(5, 2.5))
ax.plot(*bas.evaluate_on_grid(300), lw=4)
for side in ["left", "right", "top", "bottom"]:
ax.spines[side].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
plt.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02)


def plot_raised_cosine_linear():
plot_basis(nmo.basis.RaisedCosineLinearEval)


def plot_raised_cosine_log():
plot_basis(nmo.basis.RaisedCosineLogEval)


def plot_mspline():
plot_basis(nmo.basis.MSplineEval)


def plot_bspline():
plot_basis(nmo.basis.BSplineEval)


def plot_cyclic_bspline():
plot_basis(nmo.basis.CyclicBSplineEval)


def plot_orth_exp_basis():
plot_basis(nmo.basis.OrthExponentialEval)
1 change: 1 addition & 0 deletions docs/tutorials/plot_02_head_direction.md
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ worst if we needed a finer temporal resolution, such 1ms time bins
(which would require 800 coefficients instead of 80).
What can we do to mitigate over-fitting now?

(head_direction_reducing_dimensionality)=
#### Reducing feature dimensionality
One way to proceed is to find a lower-dimensional representation of the response
by parametrizing the decay effect. For instance, we could try to model it
Expand Down
11 changes: 8 additions & 3 deletions docs/tutorials/plot_03_grid_cells.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ for i in range(len(spikes)):
plt.tight_layout()
```


(grid_cells_nemos)=
## NeMoS
It's time to use NeMoS.
Let's try to predict the spikes as a function of position and see if we can generate better tuning curves
Expand All @@ -146,9 +148,9 @@ We can define a two-dimensional basis for position by multiplying two one-dimens
see [here](composing_basis_function) for more details.

```{code-cell} ipython3
basis_2d = nmo.basis.RaisedCosineLinearEval(
basis_2d = nmo.basis.BSplineEval(
n_basis_funcs=10
) * nmo.basis.RaisedCosineLinearEval(n_basis_funcs=10)
) * nmo.basis.BSplineEval(n_basis_funcs=10)
```

Let's see what a few basis look like. Here we evaluate it on a 100 x 100 grid.
Expand Down Expand Up @@ -219,7 +221,10 @@ Here we will focus on the last neuron (neuron 7) who has a nice grid pattern
```{code-cell} ipython3
model = nmo.glm.GLM(
regularizer="Ridge",
regularizer_strength=0.001
regularizer_strength=0.0001,
# lowering the tolerance means that the solution will be closer to the optimum
# (at the cost of increasing execution time)
solver_kwargs=dict(tol=10**-12),
)
```

Expand Down
Loading
Loading