Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document basis #274

Merged
merged 32 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
a6d60d7
add svgs basis
BalzaniEdoardo Dec 2, 2024
4be4c63
added table for basis
BalzaniEdoardo Dec 2, 2024
98eaf25
fix notebook
BalzaniEdoardo Dec 2, 2024
7a13bcd
fix notebook
BalzaniEdoardo Dec 2, 2024
0a53a59
ADDED ORTH exp
BalzaniEdoardo Dec 2, 2024
4777a25
fixed entry
BalzaniEdoardo Dec 2, 2024
bd88be2
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 3, 2024
ba94341
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 3, 2024
927d021
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 3, 2024
c07aced
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 3, 2024
be197c1
add script to gen figs
BalzaniEdoardo Dec 4, 2024
0297a36
merged basis refactor
BalzaniEdoardo Dec 4, 2024
e39b1c0
Fix table layout
BalzaniEdoardo Dec 4, 2024
c49ffc1
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 4, 2024
aafa384
fix background note path
BalzaniEdoardo Dec 4, 2024
0fa061f
fix background note path
BalzaniEdoardo Dec 4, 2024
4d2bcb3
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 4, 2024
44e61cc
fixing links
BalzaniEdoardo Dec 4, 2024
442f506
Merge branch 'basis_refactor_pr1' into document_basis
BalzaniEdoardo Dec 4, 2024
09e0c8e
use plot directive for thumbnail
BalzaniEdoardo Dec 4, 2024
57e2bd2
fixed description of basis
BalzaniEdoardo Dec 5, 2024
bcf72ce
merged basis PR1
BalzaniEdoardo Dec 10, 2024
2eb55a9
Merge branch 'development' into document_basis
BalzaniEdoardo Dec 11, 2024
ddc357b
tweak text in plot 1d
billbrod Dec 11, 2024
e384d10
try to add summary to plot 2d
billbrod Dec 11, 2024
a1e6ecd
correct filename typo
billbrod Dec 11, 2024
acda0f2
small fixes in basis/readme
billbrod Dec 11, 2024
387aeff
Update docs/background/basis/README.md
BalzaniEdoardo Dec 12, 2024
0fa80d4
fix paths and module name
BalzaniEdoardo Dec 12, 2024
77017d4
Merge branch 'document_basis' of github.com:flatironinstitute/nemos i…
BalzaniEdoardo Dec 12, 2024
01e6e7d
linted
BalzaniEdoardo Dec 12, 2024
afdf8a8
added pages in readme subsection
BalzaniEdoardo Dec 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions docs/assets/stylesheets/custom.css
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,29 @@ html[data-theme=light]{
font-weight: normal;
}

/*!* Style the brackets *!*/
/*span.fn-bracket {*/
/* color: #666; !* Dim the brackets *!*/
/*}*/

/*!* Style the links within the footnotes *!*/
/*aside.footnote a {*/
/* color: #007BFF; !* Blue link color *!*/
/* text-decoration: none; !* Remove underline *!*/
/*}*/

/*aside.footnote a:hover {*/
/* text-decoration: underline; !* Add underline on hover *!*/
/*}*/
table.table-center{
text-align: center;
}

#table-basis {
table-layout: auto;
width: 100%;
}

#table-basis th:nth-child(1), #table-basis td:nth-child(1) {
width: 22%;
}

#table-basis th:nth-child(2), #table-basis td:nth-child(2) {
width: 22%;
}
#table-basis th:nth-child(3), #table-basis td:nth-child(3) {
width: 10%;
}

#table-basis th:nth-child(4), #table-basis td:nth-child(4) {
width: 20%;
}
#table-basis th:nth-child(5), #table-basis td:nth-child(5) {
width: 10%;
}
23 changes: 7 additions & 16 deletions docs/background/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,19 @@ plot_00_conceptual_intro.md

:::{grid-item-card}

<figure>
<img src="../_static/thumbnails/background/plot_01_1D_basis_function.svg" style="height: 100px", alt="One-Dimensional Basis."/>
</figure>
```{eval-rst}

```{toctree}
:maxdepth: 2

plot_01_1D_basis_function.md
.. plot:: scripts/basis_figs.py plot_raised_cosine_linear
:show-source-link: False
:height: 100px
```
:::

:::{grid-item-card}

<figure>
<img src="../_static/thumbnails/background/plot_02_ND_basis_function.svg" style="height: 100px", alt="N-Dimensional Basis."/>
</figure>

```{toctree}
:maxdepth: 2
:maxdepth: 3

plot_02_ND_basis_function.md
basis/README.md
```

:::

:::{grid-item-card}
Expand Down
134 changes: 134 additions & 0 deletions docs/background/basis/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Basis Function

(table_basis)=
```{eval-rst}

.. role:: raw-html(raw)
:format: html

.. list-table::
:header-rows: 1
:name: table-basis
:align: center

* - **Basis**
- **Kernel Visualization**
- **Examples**
- **Evaluation/Convolution**
- **Preferred Mode**
* - **B-Spline**
- .. plot:: scripts/basis_figs.py plot_bspline
:show-source-link: False
:height: 80px
- :ref:`Grid cells <grid_cells_nemos>`
- :class:`~nemos.basis.BSplineEval` :raw-html:`<br />`
:class:`~nemos.basis.BSplineConv`
- 🟢 Eval
* - **Cyclic B-Spline**
- .. plot:: scripts/basis_figs.py plot_cyclic_bspline
:show-source-link: False
:height: 80px
- :ref:`Place cells <basis_eval_place_cells>`
- :class:`~nemos.basis.CyclicBSplineEval` :raw-html:`<br />`
:class:`~nemos.basis.CyclicBSplineConv`
- 🟢 Eval
* - **M-Spline**
- .. plot:: scripts/basis_figs.py plot_mspline
:show-source-link: False
:height: 80px
- :ref:`Place cells <basis_eval_place_cells>`
- :class:`~nemos.basis.MSplineEval` :raw-html:`<br />`
:class:`~nemos.basis.MSplineConv`
- 🟢 Eval
* - **Linearly Spaced Raised Cosine**
- .. plot:: scripts/basis_figs.py plot_raised_cosine_linear
:show-source-link: False
:height: 80px
-
- :class:`~nemos.basis.RaisedCosineLinearEval` :raw-html:`<br />`
:class:`~nemos.basis.RaisedCosineLinearConv`
- 🟢 Eval
* - **Log Spaced Raised Cosine**
- .. plot:: scripts/basis_figs.py plot_raised_cosine_log
:show-source-link: False
:height: 80px
- :ref:`Head Direction <head_direction_reducing_dimensionality>`
- :class:`~nemos.basis.RaisedCosineLogEval` :raw-html:`<br />`
:class:`~nemos.basis.RaisedCosineLogConv`
- 🔵 Conv
* - **Orthogonalized Exponential Decays**
- .. plot:: scripts/basis_figs.py plot_orth_exp_basis
:show-source-link: False
:height: 80px
-
- :class:`~nemos.basis.OrthExponentialEval` :raw-html:`<br />`
:class:`~nemos.basis.OrthExponentialConv`
- 🟢 Eval
```

## Overview

A basis function is a collection of simple building blocks—functions that, when combined (weighted and summed together), can represent more complex, non-linear relationships. Think of them as tools for constructing predictors in GLMs, helping to model:

1. **Non-linear mappings** between task variables (like velocity or position) and firing rates.
2. **Linear temporal effects**, such as spike history, neuron-to-neuron couplings, or how stimuli are integrated over time.

In a GLM, we assume a non-linear mapping exists between task variables and neuronal firing rates. This mapping isn’t something we can directly observe—what we do see are the inputs (task covariates) and the resulting neural activity. The challenge is to infer a "good" approximation of this hidden relationship.

Basis functions help simplify this process by representing the non-linearity as a weighted sum of fixed functions, $\psi_1(x), \dots, \psi_n(x)$, with weights $\alpha_1, \dots, \alpha_n$. Mathematically:

$$
f(x) \approx \alpha_1 \psi_1(x) + \dots + \alpha_n \psi_n(x)
$$

Here, $\approx$ means "approximately equal".

Instead of tackling the hard problem of learning an unknown function $f(x)$ directly, we reduce it to the simpler task of learning the weights $\{\alpha_i\}$. This preserves convexity, resulting in a much simpler optimization problem.


## Basis in NeMoS

NeMoS provides a variety of basis functions (see the [table](table_basis) above). For each basis type, there are two dedicated classes of objects, corresponding to the two uses described above:

- **Eval basis objects**: For representing non-linear mappings between task variables and outputs. These objects all have names ending with `Eval`.
- **Conv basis objects**: For linear temporal effects. These objects all have names ending with `Conv`.

`Eval` and `Conv` objects can be combined to construct multi-dimensional basis functions, enabling [complex feature construction](composing_basis_function).

## Learn More

::::{grid} 1 2 2 2

:::{grid-item-card}

```{eval-rst}

.. plot:: scripts/basis_figs.py plot_1d_basis_thumbnail
:show-source-link: False
:height: 100px
```

```{toctree}
:maxdepth: 2

plot_01_1D_basis_function.md
```
:::

:::{grid-item-card}

```{eval-rst}

.. plot:: scripts/basis_figs.py plot_nd_basis_thumbnail
:show-source-link: False
:height: 100px
```

```{toctree}
:maxdepth: 2

plot_02_ND_basis_function.md
```
:::

::::
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ warnings.filterwarnings(

## Defining a 1D Basis Object

We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.MSplineEval).
The hyperparameters required to initialize this class are:
We'll start by defining a 1D basis function object of the type [`BSplineEval`](nemos.basis.BSplineEval).
The hyperparameters needed to initialize this class are:

- The number of basis functions, which should be a positive integer.
- The order of the spline, which should be an integer greater than 1.
- The number of basis functions, which should be a positive integer (required).
- The order of the spline, which should be an integer greater than 1 (optional, default 4 for a cubic spline).

```{code-cell} ipython3
import matplotlib.pylab as plt
Expand Down Expand Up @@ -81,46 +81,26 @@ plt.plot(x, y, lw=2)
plt.title("B-Spline Basis")
```

```{code-cell} ipython3
:tags: [hide-input]

# save image for thumbnail
from pathlib import Path
import os

root = os.environ.get("READTHEDOCS_OUTPUT")
if root:
path = Path(root) / "html/_static/thumbnails/background"
# if local store in ../_build/html/...
else:
path = Path("../_build/html/_static/thumbnails/background")

# make sure the folder exists if run from build
if root or Path("../_build/html/_static").exists():
path.mkdir(parents=True, exist_ok=True)

if path.exists():
fig.savefig(path / "plot_01_1D_basis_function.svg")
```

## Feature Computation
The bases in the `nemos.basis` module can be grouped into two categories:
## Computing Features
All bases in the `nemos.basis` module perform a transformation of one or more time series into a set of features. This operation is always carried out by the method [`compute_features`](nemos.basis._basis.Basis.compute_features).
We can group the bases into two categories depending on the type of transformation that [`compute_features`](nemos.basis._basis.Basis.compute_features) applies:

1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ending with "Eval," such as `BSplineEval`.
1. **Evaluation Bases**: These bases use `compute_features` to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ending with "Eval," such as `BSplineEval`.

2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv," such as `BSplineConv`.
2. **Convolution Bases**: These bases use `compute_features` to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv", such as `BSplineConv`.

Let's see how this two modalities operate.
Let's see how these two categories operate:

```{code-cell} ipython3
eval_mode = nmo.basis.MSplineEval(n_basis_funcs=n_basis)
conv_mode = nmo.basis.MSplineConv(n_basis_funcs=n_basis, window_size=100)
eval_mode = nmo.basis.BSplineEval(n_basis_funcs=n_basis)
conv_mode = nmo.basis.BSplineConv(n_basis_funcs=n_basis, window_size=100)

# define an input
angles = np.linspace(0, np.pi*4, 201)
y = np.cos(angles)

# compute features in the two modalities
# compute features
eval_feature = eval_mode.compute_features(y)
conv_feature = conv_mode.compute_features(y)

Expand Down Expand Up @@ -153,19 +133,21 @@ Convolution is performed in "valid" mode, and then NaN-padded. The default behav
is padding left, which makes the output feature causal.
This is why the first half of the `conv_feature` is full of NaNs and appears as white.
If you want to learn more about convolutions, as well as how and when to change defaults
check out the tutorial on [1D convolutions](plot_03_1D_convolution).
check out the tutorial on [1D convolutions](convolution_background).
:::


Plotting the Basis Function Elements:
Plotting the Basis Function Elements
--------------------------------------
We suggest visualizing the basis post-instantiation by evaluating each element on a set of equi-spaced sample points
and then plotting the result. The method [`Basis.evaluate_on_grid`](nemos.basis._basis.Basis.evaluate_on_grid) is designed for this, as it generates and returns
the equi-spaced samples along with the evaluated basis functions. The benefits of using Basis.evaluate_on_grid become
particularly evident when working with multidimensional basis functions. You can find more details and visual
background in the
[2D basis elements plotting section](plotting-2d-additive-basis-elements).
the equi-spaced samples along with the evaluated basis functions.

:::{admonition} Note

The array returned by `evaluate_on_grid(n_samples)` is the same as the kernel that is used by the Conv bases initialized with `window_sizes=n_samples`!

:::

```{code-cell} ipython3
# Call evaluate on grid on 100 sample points to generate samples and evaluate the basis at those samples
Expand All @@ -179,12 +161,13 @@ plt.plot(equispaced_samples, eval_basis)
plt.show()
```

The benefits of using `evaluate_on_grid` become particularly evident when working with multidimensional basis functions. You can find more details in the [2D basis elements plotting section](plotting-2d-additive-basis-elements).

## Setting the basis support (Eval only)
Sometimes, it is useful to restrict the basis to a fixed range. This can help manage outliers or ensure that
your basis covers the same range across multiple experimental sessions.
You can specify a range for the support of your basis by setting the `bounds`
parameter at initialization of "Eval" type basis (it doesn't make sense for convolutions).
parameter at initialization of Eval bases.
Evaluating the basis at any sample outside the bounds will result in a NaN.


Expand All @@ -210,26 +193,3 @@ axs[1].plot(samples, bspline_range.compute_features(samples), color="tomato")
axs[1].set_title("bounds=[0.2, 0.8]")
plt.tight_layout()
```

Other Basis Types
-----------------
Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description,
please refer to the [API Guide](nemos_basis). After instantiation, all classes
share the same syntax for basis evaluation. The following is an example of how to instantiate and
evaluate a log-spaced cosine raised function basis.


```{code-cell} ipython3
# Instantiate the basis noting that the `RaisedCosineLog` basis does not require an `order` parameter
raised_cosine_log = nmo.basis.RaisedCosineLogEval(n_basis_funcs=10, width=1.5, time_scaling=50)

# Evaluate the raised cosine basis at the equi-spaced sample points
# (same method in all Basis elements)
samples, eval_basis = raised_cosine_log.evaluate_on_grid(100)

# Plot the evaluated log-spaced raised cosine basis
plt.figure()
plt.title(f"Log-spaced Raised Cosine basis with {eval_basis.shape[1]} elements")
plt.plot(samples, eval_basis)
plt.show()
```
Loading
Loading