Hide code cell source

%load_ext autoreload
%autoreload 2

%matplotlib inline
import warnings

warnings.filterwarnings(
    "ignore",
    message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
    category=UserWarning,
)

warnings.filterwarnings(
    "ignore",
    message="Ignoring cached namespace 'core'",
    category=UserWarning,
)

warnings.filterwarnings(
    "ignore",
    message=(
        "invalid value encountered in div "
    ),
    category=RuntimeWarning,
)

Download

This notebook can be downloaded as place_cells.ipynb. See the button at the top right to download as markdown or pdf.

Model selection#

In this session we will keep working on the hippocampal place field recordings. In particular, we will learn how to model neural responses to multiple predictors: position, speed and theta phase. We will also learn how we can leverage GLM to assess which predictor is more informative for explaining the observed activity.

Data for this notebook comes from recordings in the mouse hippocampus while the mouse runs on a linear track, which we [explored yesterday](../day1/phase_precession-users.md).

Learning objectives#

  • Review how to use pynapple to analyze neuronal tuning

  • Learn how to combine NeMoS basis objects for modeling multiple predictors

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pynapple as nap

import nemos as nmo

# some helper plotting functions
from nemos import _documentation_utils as doc_plots
import workshop_utils

# configure plots some
plt.style.use(nmo.styles.plot_style)

import workshop_utils

from sklearn import model_selection
from sklearn import pipeline

# shut down jax to numpy conversion warning
nap.nap_config.suppress_conversion_warnings = True

# during development, set this to a lower number so everything runs faster. 
cv_folds = 5
WARNING:2025-10-24 20:33:35,479:jax._src.xla_bridge:850: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Pynapple#

- Load the data using pynapple.
path = workshop_utils.fetch_data("Achilles_10252013_EEG.nwb")
data = nap.load_file(path)
data
Achilles_10252013_EEG
┍━━━━━━━━━━━━━┯━━━━━━━━━━━━━┑
│ Keys        │ Type        │
┝━━━━━━━━━━━━━┿━━━━━━━━━━━━━┥
│ units       │ TsGroup     │
│ rem         │ IntervalSet │
│ nrem        │ IntervalSet │
│ forward_ep  │ IntervalSet │
│ eeg         │ TsdFrame    │
│ theta_phase │ Tsd         │
│ position    │ Tsd         │
┕━━━━━━━━━━━━━┷━━━━━━━━━━━━━┙
- Extract the spike times and mouse position.
spikes = data["units"]
position = data["position"]

For today, we’re only going to focus on the times when the animal was traversing the linear track. This is a pynapple IntervalSet, so we can use it to restrict our other variables:

  • Restrict data to when animal was traversing the linear track.

position = position.restrict(data["forward_ep"])
spikes = spikes.restrict(data["forward_ep"])

The recording contains both inhibitory and excitatory neurons. Here we will focus of the excitatory cells with firing above 0.3 Hz.

  • Restrict neurons to only excitatory neurons, discarding neurons with a low-firing rate.

spikes = spikes.getby_category("cell_type")["pE"]
spikes = spikes.getby_threshold("rate", 0.3)

Place fields#

By plotting the neuronal firing rate as a function of position, we can see that these neurons are all tuned for position: they fire in a specific location on the track.

  • Visualize the place fields: neuronal firing rate as a function of position.

place_fields = nap.compute_tuning_curves(spikes, position, bins=50, epochs=position.time_support, feature_names=["distance"])
workshop_utils.plot_place_fields(place_fields)
../../_images/e6eaa2cc227feeace3fc1b83b68ee10a929cdb38e7a293ad5adc7e78088f3fa5.png

To decrease computation time, we’re going to spend the rest of the notebook focusing on the neurons highlighted above. We’re also going to bin spikes at 100 Hz and up-sample the position to match that temporal resolution.

  • For speed, we’re only going to investigate the three neurons highlighted above.

  • Bin spikes to counts at 100 Hz.

  • Interpolate position to match spike resolution.

neurons = [82, 92, 220]
place_fields = place_fields.sel(unit=neurons)
spikes = spikes[neurons]
bin_size = .01
count = spikes.count(bin_size, ep=position.time_support)
position = position.interpolate(count, ep=count.time_support)
print(count.shape)
print(position.shape)
(19237, 3)
(19237,)

Speed modulation#

The speed at which the animal traverse the field is not homogeneous. Does it influence the firing rate of hippocampal neurons? We can compute tuning curves for speed as well as average speed across the maze. In the next block, we compute the speed of the animal for each epoch (i.e. crossing of the linear track) by doing the difference of two consecutive position multiplied by the sampling rate of the position.

  • Compute animal’s speed for each epoch.

speed = []
# Analyzing each epoch separately avoids edge effects.
for s, e in position.time_support.values: 
    pos_ep = position.get(s, e)
    # Absolute difference of two consecutive points
    speed_ep = np.abs(np.diff(pos_ep)) 
    # Padding the edge so that the size is the same as the position/spike counts
    speed_ep = np.pad(speed_ep, [0, 1], mode="edge") 
    # Converting to cm/s 
    speed_ep = speed_ep * position.rate
    speed.append(speed_ep)

speed = nap.Tsd(t=position.t, d=np.hstack(speed), time_support=position.time_support)
print(speed.shape)
(19237,)

Now that we have the speed of the animal, we can compute the tuning curves for speed modulation. Here we call pynapple compute_tuning_curves:

tc_speed = nap.compute_tuning_curves(spikes, speed, bins=20, epochs=speed.time_support, feature_names=["speed"])
  • Visualize the position and speed tuning for these neurons.

fig = workshop_utils.plot_position_speed(position, speed, place_fields, tc_speed, neurons);
../../_images/6bd34b8978faabaa98cb852c410432b4c4c4cfd028202e2e8652f480713364c9.png

These neurons show a strong modulation of firing rate as a function of speed but we also notice that the animal, on average, accelerates when traversing the field. Is the speed tuning we observe a true modulation or spurious correlation caused by traversing the place field at different speeds? We can use NeMoS to model the activity and give the position and the speed as input variable.

These neurons all show both position and speed tuning, and we see that the animal’s speed and position are highly correlated. We’re going to build a GLM to predict neuronal firing rate – which variable should we use? Is the speed tuning just epiphenomenal?

NeMoS#

Basis evaluation#

As we’ve seen before, we will use basis objects to represent the input values. In previous tutorials, we’ve used the Conv basis objects to represent the time-dependent effects we were looking to capture. Here, we’re trying to capture the non-linear relationship between our input variables and firing rate, so we want the Eval objects. In these circumstances, you should look at the tuning you’re trying to capture and compare to the basis kernels (visualized in NeMoS docs): you want your tuning to be capturable by a linear combination of them.

In this case, several of these would probably work; we will use MSplineEval for both, though with different numbers of basis functions.

Additionally, since we have two different inputs, we’ll need two separate basis objects.

Note

Later in this notebook, we’ll show how to cross-validate across basis identity, which you can use to choose the basis.

  • why basis?

    • without basis:

      • either the GLM says that firing rate increases exponentially as position or speed increases, which is fairly nonsensical,

      • or we have to fit the weight separately for each position or speed, which is really high-dim

    • so, basis allows us to reduce dimensionality, capture non-linear modulation of firing rate (in this case, tuning)

  • why eval?

    • basis objects have two modes:

    • conv, like we’ve seen, for capturing time-dependent effects

    • eval, for capturing non-linear modulation / tuning

  • why MSpline?

    • when deciding on eval basis, look at the tuning you want to capture, compare to the kernels: you want your tuning to be capturable by a linear combination of these

    • in cases like this, many possible basis objects we could use here and what I’ll show you in a bit will allow you to determine which to use in principled manner

    • MSpline, BSpline, RaisedCosineLinear : all would let you capture this

    • weird choices:

      • cyclic bspline, except maybe for position? if end and start are the same

      • RaisedCosineLog (don’t want the stretching)

      • orthogonalized exponential (specialized for…)

      • identity / history (too basic)

  • Create a separate basis object for each model input.

  • Visualize the basis objects.

position_basis = nmo.basis.MSplineEval(n_basis_funcs=10)
speed_basis = nmo.basis.MSplineEval(n_basis_funcs=15)
workshop_utils.plot_pos_speed_bases(position_basis, speed_basis)
../../_images/336976f50954357f9314e462ab267fbd4244cd2eda40e45c3416602e3042dff3.png

However, now we have an issue: in all our previous examples, we had a single basis object, which took a single input to produce a single array which we then passed to the GLM object as the design matrix. What do we do when we have multiple basis objects?

We could call basis.compute_features() for each basis separately and then concatenated the outputs, but then we have to remember the order we concatenated them in and that behavior gets unwieldy as we add more bases.

Instead, NeMoS allows us to combine multiple basis objects into a single “additive basis”, which we can pass all of our inputs to in order to produce a single design matrix:

  • Combine the two basis objects into a single “additive basis”

# equivalent to calling nmo.basis.AdditiveBasis(position_basis, speed_basis)
basis = position_basis + speed_basis
  • Create the design matrix!

  • Notice that, since we passed the basis pynapple objects, we got one back, preserving the time stamps.

  • X has the same number of time points as our input position and speed, but 25 columns. The columns come from n_basis_funcs from each basis (10 for position, 15 for speed).

X = basis.compute_features(position, speed)
X
Time (s)                0           1            2            3    4  ...
---------------  --------  ----------  -----------  -----------  ---  -----
18193.603802655  0.162854  0.0063002   5.33261e-05  1.12053e-07    0  ...
18193.613802655  0.159556  0.00790116  8.5239e-05   2.27854e-07    0  ...
18193.623802655  0.156303  0.00946868  0.000124422  4.04333e-07    0  ...
18193.633802655  0.151106  0.0119478   0.000203448  8.54051e-07    0  ...
18193.643802655  0.145902  0.014398    0.000303689  1.57398e-06    0  ...
18193.653802655  0.141967  0.0162281   0.000394145  2.34623e-06    0  ...
18193.663802655  0.139957  0.0171555   0.000445419  2.8306e-06     0  ...
...                                                                   ...
20123.332682821  0         0           0            0              0  ...
20123.342682821  0         0           0            0              0  ...
20123.352682821  0         0           0            0              0  ...
20123.362682821  0         0           0            0              0  ...
20123.372682821  0         0           0            0              0  ...
20123.382682821  0         0           0            0              0  ...
20123.392682821  0         0           0            0              0  ...
dtype: float64, shape: (19237, 25)

Model learning#

As we’ve done before, we can now use the Poisson GLM from NeMoS to learn the combined model:

  • Initialize PopulationGLM

  • Use the “LBFGS” solver and pass {"tol": 1e-12} to solver_kwargs.

  • Fit the data, passing the design matrix and spike counts to the glm object.

glm = nmo.glm.PopulationGLM(
    solver_kwargs={"tol": 1e-12},
    solver_name="LBFGS",
)

glm.fit(X, count)
PopulationGLM(
    observation_model=PoissonObservations(),
    inverse_link_function=exp,
    regularizer=UnRegularized(),
    solver_name='LBFGS',
    solver_kwargs={'tol': 1e-12}
)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

Prediction#

Let’s check first if our model can accurately predict the tuning curves we displayed above. We can use the predict function of NeMoS and then compute new tuning curves

  • Use predict to check whether our GLM has captured each neuron’s speed and position tuning.

  • Remember to convert the predicted firing rate to spikes per second!

# predict the model's firing rate
predicted_rate = glm.predict(X) / bin_size

# same shape as the counts we were trying to predict
print(predicted_rate.shape, count.shape)

# compute the position and speed tuning curves using the predicted firing rate.
glm_pos = nap.compute_tuning_curves(predicted_rate, position, bins=50, epochs=position.time_support, feature_names=["position"])
glm_speed = nap.compute_tuning_curves(predicted_rate, speed, bins=30, epochs=speed.time_support, feature_names=["speed"])
(19237, 3) (19237, 3)
  • Compare model and data tuning curves together. The model did a pretty good job!

workshop_utils.plot_position_speed_tuning(place_fields, tc_speed, glm_pos, glm_speed);
../../_images/2a01e70c766ac7fe9fc2500758f06849fa9661315b2fb51a77082381c59faaad.png

We can see that this model does a good job capturing both the position and the speed. In the rest of this notebook, we’re going to investigate all the scientific decisions that we swept under the rug: should we regularize the model? what basis should we use? do we need both inputs?

To make our lives easier, let’s create a helper function that wraps the above lines, because we’re going to be visualizing our model predictions a lot.

def visualize_model_predictions(glm, X):
    # predict the model's firing rate
    predicted_rate = glm.predict(X) / bin_size

    # compute the position and speed tuning curves using the predicted firing rate.
    glm_pos = nap.compute_tuning_curves(predicted_rate, position, bins=50, epochs=position.time_support, feature_names=["position"])
    glm_speed = nap.compute_tuning_curves(predicted_rate, speed, bins=30, epochs=position.time_support, feature_names=["speed"])

    workshop_utils.plot_position_speed_tuning(place_fields, tc_speed, glm_pos, glm_speed);

In our previous analysis of the place field hyppocampal dataset we compared multiple encoding models and tried to figure out which predictor (position, speed or phase) had more explanatory power. In this notebook we will keep going on that effort and learn more principled (and convenient) approaches to model comparison combining NeMoS and scikit-learn.

Learning Objectives#

  • Learn how to use NeMoS objects with scikit-learn for cross-validation

  • Learn how to use NeMoS objects with scikit-learn pipelines

  • Learn how to use cross-validation to perform model and feature selection

Scikit-learn#

How to know when to regularize?#

In the last session, Edoardo fit the all-to-all connectivity of the head-tuning dataset using the Ridge regularizer. In the model above, we’re not using any regularization? Why is that?

We have far fewer parameters here then in the last example. However, how do you know if you need regularization or not? One thing you can do is use cross-validation to see whether model performance improves with regularization (behind the scenes, this is what we did!). We’ll walk through how to do that now.

Instead of implementing our own cross-validation machinery, the developers of nemos decided that we should write the package to be compliant with scikit-learn, the canonical machine learning python library. Our models are all what scikit-learn calls “estimators”, which means they have .fit, .score. and .predict methods. Thus, we can use them with scikit-learn’s objects out of the box.

We’re going to use scikit-learn’s GridSearchCV object, which performs a cross-validated grid search, as Edoardo explained in his presentation.

This object requires an estimator, our glm object here, and param_grid, a dictionary defining what to check. For now, let’s just compare Ridge regularization with no regularization:

  • How do we decide when to use regularization?

  • Cross-validation allows you to fairly compare different models on the same dataset.

  • NeMoS makes use of scikit-learn, the standard machine learning library in python.

  • Define parameter grid to search over.

  • Anything not specified in grid will be kept constant.

# define a Ridge GLM
glm = nmo.glm.PopulationGLM(
    regularizer="Ridge",
    solver_kwargs={"tol": 1e-12},
    solver_name="LBFGS",
)
param_grid = {
    "regularizer_strength": [0.0001, 1.],
}
cv = model_selection.GridSearchCV(glm, param_grid, cv=cv_folds)
cv
GridSearchCV(cv=5,
             estimator=PopulationGLM(inverse_link_function=<PjitFunction of <function exp at 0x7efda86f6fc0>>, observation_model=PoissonObservations(), regularizer=Ridge(), regularizer_strength=1.0, solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'),
             param_grid={'regularizer_strength': [0.0001, 1.0]})
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

This will take a bit to run, because we’re fitting the model many times!

  • We interact with this in a very similar way to the glm object.

  • In particular, call fit with same arguments:

cv.fit(X, count)
GridSearchCV(cv=5,
             estimator=PopulationGLM(inverse_link_function=<PjitFunction of <function exp at 0x7efda86f6fc0>>, observation_model=PoissonObservations(), regularizer=Ridge(), regularizer_strength=1.0, solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'),
             param_grid={'regularizer_strength': [0.0001, 1.0]})
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
  • We got a warning because we didn’t specify the regularizer strength, so we just fell back on default value.

  • Let’s investigate results:

Cross-validation results are stored in a dictionary attribute called cv_results_, which contains a lot of info. Let’s convert that to a pandas dataframe for readability,

import pandas as pd

pd.DataFrame(cv.cv_results_)
mean_fit_time std_fit_time mean_score_time std_score_time param_regularizer_strength params split0_test_score split1_test_score split2_test_score split3_test_score split4_test_score mean_test_score std_test_score rank_test_score
0 4.292632 0.901613 0.343284 0.416243 0.0001 {'regularizer_strength': 0.0001} -0.128512 -0.117619 -0.160407 -0.130696 -0.124069 -0.132261 0.014766 1
1 1.697351 0.112994 0.003119 0.000073 1.0000 {'regularizer_strength': 1.0} -0.137128 -0.128441 -0.173226 -0.138691 -0.131280 -0.141753 0.016175 2

The most informative for us is the 'mean_test_score' key, which shows the average of glm.score on each test-fold. Thus, higher is better, and we can see that the UnRegularized model performs better.

Select basis#

We can do something similar to select the basis. In the above example, I just told you which basis function to use and how many of each. But, in general, you want to select those in a reasonable manner. Cross-validation to the rescue!

Unlike the glm objects, our basis objects are not scikit-learn compatible right out of the box. However, they can be made compatible by using the .to_transformer() method (or, equivalently, by using the TransformerBasis class)

  • You can (and should) do something similar to determine how many basis functions you need for each input.

  • NeMoS basis objects are not scikit-learn-compatible right out of the box.

  • But we have provided a simple method to make them so:

position_basis = nmo.basis.MSplineEval(n_basis_funcs=10, label="position").to_transformer()
# or equivalently:
position_basis = nmo.basis.TransformerBasis(nmo.basis.MSplineEval(n_basis_funcs=10, label="position"))
position_basis
'position': MSplineEval(n_basis_funcs=10, order=4)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

This gives the basis object the transform method, which is equivalent to compute_features. However, transformers have some limits:

  • This gives the basis object the transform method, which is equivalent to compute_features.

  • However, transformers have some limits:

position_basis.transform(position)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[24], line 1
----> 1 position_basis.transform(position)

File ~/lib/python3.12/site-packages/nemos/basis/_transformer_basis.py:245, in TransformerBasis.transform(self, X, y)
    208 """
    209 Transform the data using the fitted basis functions.
    210 
   (...)    242 >>> feature_transformed = transformer.transform(X)
    243 """
    244 self._check_initialized(self.basis)
--> 245 self._check_input(X, y)
    246 # transpose does not work with pynapple
    247 # can't use func(*X.T) to unwrap
    248 return self.basis.compute_features(*self._unpack_inputs(X))

File ~/lib/python3.12/site-packages/nemos/basis/_transformer_basis.py:553, in TransformerBasis._check_input(self, X, y)
    550     raise ValueError("The input must be a 2-dimensional array.")
    552 elif ndim != 2:
--> 553     raise ValueError(
    554         f"X must be 2-dimensional, shape (n_samples, n_features). The provided X has shape {X.shape} instead."
    555     )
    557 if X.shape[1] != sum(self._input_shape_product):
    558     raise ValueError(
    559         f"Input mismatch: expected {sum(self._input_shape_product)} inputs, but got {X.shape[1]} "
    560         f"columns in X.\nTo modify the required number of inputs, call `set_input_shape` before using "
    561         f"`fit` or `fit_transform`."
    562     )

ValueError: X must be 2-dimensional, shape (n_samples, n_features). The provided X has shape (19237,) instead.
  • Transformers only accept 2d inputs, whereas nemos basis objects can accept inputs of any dimensionality.

Transformers only accept 2d inputs, whereas nemos basis objects can accept inputs of any dimensionality.

position_basis.transform(position[:, np.newaxis])
Time (s)                0           1            2            3    4  ...
---------------  --------  ----------  -----------  -----------  ---  -----
18193.603802655  0.162854  0.0063002   5.33261e-05  1.12053e-07    0  ...
18193.613802655  0.159556  0.00790116  8.5239e-05   2.27854e-07    0  ...
18193.623802655  0.156303  0.00946868  0.000124422  4.04333e-07    0  ...
18193.633802655  0.151106  0.0119478   0.000203448  8.54051e-07    0  ...
18193.643802655  0.145902  0.014398    0.000303689  1.57398e-06    0  ...
18193.653802655  0.141967  0.0162281   0.000394145  2.34623e-06    0  ...
18193.663802655  0.139957  0.0171555   0.000445419  2.8306e-06     0  ...
...                                                                   ...
20123.332682821  0         0           0            0              0  ...
20123.342682821  0         0           0            0              0  ...
20123.352682821  0         0           0            0              0  ...
20123.362682821  0         0           0            0              0  ...
20123.372682821  0         0           0            0              0  ...
20123.382682821  0         0           0            0              0  ...
20123.392682821  0         0           0            0              0  ...
dtype: float64, shape: (19237, 10)
  • If the basis is composite (for example, the addition of two 1D bases), the transformer will expect a shape of (n_sampels, 1) each 1D component. If that’s not the case, you need to call set_input_shape:

If the basis has more than one component (for example, if it is the addition of two 1D bases), the transformer will expect an input shape of (n_sampels, 1) pre component. If that’s not the case, you’ll provide a different input shape by calling set_input_shape.

Option 1) One input per component:

# generate a composite basis
basis_2d = nmo.basis.MSplineEval(5) +  nmo.basis.MSplineEval(5)
basis_2d = basis_2d.to_transformer()

# this will work: 1 input per component
x, y = np.random.randn(10, 1), np.random.randn(10, 1)
X = np.concatenate([x, y], axis=1)
result = basis_2d.transform(X)

Option 2) Multiple input per component.

  • Then you can call transform on the 2d input as expected.

# Assume 2 input for the first component and 3 for the second.
x, y = np.random.randn(10, 2), np.random.randn(10, 3)
X = np.concatenate([x, y], axis=1)
try:
    basis_2d.transform(X)
except Exception as e:
    print("Exception Raised:")
    print(repr(e))

# Set the expected input shape instead.

# array
res1 = basis_2d.set_input_shape(x, y).transform(X)
# int
res2 = basis_2d.set_input_shape(2, 3).transform(X)
# tuple
res3 = basis_2d.set_input_shape((2,), (3,)).transform(X)
Exception Raised:
ValueError('Input mismatch: expected 2 inputs, but got 5 columns in X.\nTo modify the required number of inputs, call `set_input_shape` before using `fit` or `fit_transform`.')
  • You can, equivalently, call compute_features before turning the basis into a transformer. Then we cache the shape for future use:

position_basis = nmo.basis.MSplineEval(n_basis_funcs=10, label="position")
position_basis.compute_features(position)
position_basis = position_basis.to_transformer()
speed_basis = nmo.basis.MSplineEval(n_basis_funcs=15, label="speed").to_transformer().set_input_shape(1)
basis = position_basis + speed_basis
basis
'(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

Let’s create a single TsdFrame to hold all our inputs:

  • Create a single TsdFrame to hold all our inputs:

transformer_input = nap.TsdFrame(
    t=position.t,
    d=np.stack([position, speed], 1),
    time_support=position.time_support,
    columns=["position", "speed"],
)
  • Pass this input to our transformed additive basis:

Our new additive transformer basis can then take these behavioral inputs and turn them into the model’s design matrix.

basis.transform(transformer_input)
Time (s)                0           1            2            3    4  ...
---------------  --------  ----------  -----------  -----------  ---  -----
18193.603802655  0.162854  0.0063002   5.33261e-05  1.12053e-07    0  ...
18193.613802655  0.159556  0.00790116  8.5239e-05   2.27854e-07    0  ...
18193.623802655  0.156303  0.00946868  0.000124422  4.04333e-07    0  ...
18193.633802655  0.151106  0.0119478   0.000203448  8.54051e-07    0  ...
18193.643802655  0.145902  0.014398    0.000303689  1.57398e-06    0  ...
18193.653802655  0.141967  0.0162281   0.000394145  2.34623e-06    0  ...
18193.663802655  0.139957  0.0171555   0.000445419  2.8306e-06     0  ...
...                                                                   ...
20123.332682821  0         0           0            0              0  ...
20123.342682821  0         0           0            0              0  ...
20123.352682821  0         0           0            0              0  ...
20123.362682821  0         0           0            0              0  ...
20123.372682821  0         0           0            0              0  ...
20123.382682821  0         0           0            0              0  ...
20123.392682821  0         0           0            0              0  ...
dtype: float64, shape: (19237, 25)

Pipelines#

We need one more step: scikit-learn cross-validation operates on an estimator, like our GLMs. if we want to cross-validate over the basis or its features, we need to combine our transformer basis with the estimator into a single estimator object. Luckily, scikit-learn provides tools for this: pipelines.

Pipelines are objects that accept a series of (0 or more) transformers, culminating in a final estimator. This is defined as a list of tuples, with each tuple containing a human-readable label and the object itself:

  • If we want to cross-validate over the basis, we need more one more step: combining the basis and the GLM into a single scikit-learn estimator.

  • Pipelines to the rescue!

# set the reg strength to the optimal
glm = nmo.glm.PopulationGLM(solver_name="LBFGS", solver_kwargs={"tol": 10**-12})
pipe = pipeline.Pipeline([
    ("basis", basis),
    ("glm", glm)
])
pipe
Pipeline(steps=[('basis',
                 Transformer('(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
                ('glm',
                 PopulationGLM(inverse_link_function=<PjitFunction of <function exp at 0x7efda86f6fc0>>, observation_model=PoissonObservations(), regularizer=UnRegularized(), solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

This pipeline object allows us to e.g., call fit using the initial input:

  • Pipeline runs basis.transform, then passes that output to glm, so we can do everything in a single line:

pipe.fit(transformer_input, count)
Pipeline(steps=[('basis',
                 Transformer('(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
                ('glm',
                 PopulationGLM(feature_mask=Array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32), inverse_link_function=<PjitFunction of <function exp at 0x7efda86f6fc0>>, observation_model=PoissonObservations(), regularizer=UnRegularized(), solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

We then visualize the predictions the same as before, using pipe instead of glm.

  • Visualize model predictions!

visualize_model_predictions(pipe, transformer_input)
../../_images/2a01e70c766ac7fe9fc2500758f06849fa9661315b2fb51a77082381c59faaad.png

Cross-validating on the basis#

Now that we have our pipeline estimator, we can cross-validate on any of its parameters!

pipe.steps
[('basis',
  Transformer('(position + speed)': AdditiveBasis(
      basis1='position': MSplineEval(n_basis_funcs=10, order=4),
      basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
  ))),
 ('glm',
  PopulationGLM(
      observation_model=PoissonObservations(),
      inverse_link_function=exp,
      regularizer=UnRegularized(),
      solver_name='LBFGS',
      solver_kwargs={'tol': 1e-12}
  ))]

Let’s cross-validate on the number of basis functions for the position basis, and the identity of the basis for the speed. That is:

Let’s cross-validate on:

  • The number of the basis functions of the position basis

  • The functional form of the basis for speed

print(pipe["basis"]["position"].n_basis_funcs)
print(pipe["basis"]["speed"])
10
Transformer('speed': MSplineEval(n_basis_funcs=15, order=4))

For scikit-learn parameter grids, we use __ to stand in for .:

  • Construct param_grid, using __ to stand in for .

  • In sklearn pipelines, we access nested parameters using double underscores:

    • pipe["basis"]["position"].n_basis_funcs ← normal Python syntax

    • "basis__position__n_basis_funcs" ← sklearn parameter grid syntax

param_grid = {
    "basis__position__n_basis_funcs": [5, 10, 20],
    "basis__speed": [nmo.basis.MSplineEval(15).set_input_shape(1),
                      nmo.basis.BSplineEval(15).set_input_shape(1),
                      nmo.basis.RaisedCosineLinearEval(15).set_input_shape(1)],
}
  • Cross-validate as before:

cv = model_selection.GridSearchCV(pipe, param_grid, cv=cv_folds)
cv.fit(transformer_input, count)
GridSearchCV(cv=5,
             estimator=Pipeline(steps=[('basis',
                                        Transformer('(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
                                       ('glm',
                                        PopulationGLM(feature_mask=Array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]...
       [1., 1., 1.]], dtype=float32), inverse_link_function=<PjitFunction of <function exp at 0x7efda86f6fc0>>, observation_model=PoissonObservations(), regularizer=UnRegularized(), solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'))]),
             param_grid={'basis__position__n_basis_funcs': [5, 10, 20],
                         'basis__speed': [MSplineEval(n_basis_funcs=15),
                                          BSplineEval(n_basis_funcs=15),
                                          RaisedCosineLinearEval(n_basis_funcs=15)]})
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
  • Investigate results:

pd.DataFrame(cv.cv_results_)
mean_fit_time std_fit_time mean_score_time std_score_time param_basis__position__n_basis_funcs param_basis__speed params split0_test_score split1_test_score split2_test_score split3_test_score split4_test_score mean_test_score std_test_score rank_test_score
0 1.893693 0.316726 0.183834 0.207669 5 (((MSplineEval(n_basis_funcs=15, order=4)))) {'basis__position__n_basis_funcs': 5, 'basis__... -0.118457 -0.105659 -0.142501 -0.118152 -0.116768 -0.120307 0.012064 3
1 1.622862 0.072675 0.015925 0.000366 5 (((BSplineEval(n_basis_funcs=15, order=4)))) {'basis__position__n_basis_funcs': 5, 'basis__... -0.123633 -0.106470 -0.147680 -0.119568 -0.121160 -0.123702 0.013382 8
2 1.647547 0.101936 0.012583 0.000576 5 (((RaisedCosineLinearEval(n_basis_funcs=15, wi... {'basis__position__n_basis_funcs': 5, 'basis__... -0.123348 -0.105752 -0.152880 -0.119239 -0.120581 -0.124360 0.015500 9
3 1.846706 0.276718 0.017661 0.002708 10 (((MSplineEval(n_basis_funcs=15, order=4)))) {'basis__position__n_basis_funcs': 10, 'basis_... -0.113527 -0.104531 -0.141387 -0.117203 -0.115781 -0.118486 0.012270 1
4 1.661602 0.077783 0.018340 0.002006 10 (((BSplineEval(n_basis_funcs=15, order=4)))) {'basis__position__n_basis_funcs': 10, 'basis_... -0.115269 -0.104268 -0.144509 -0.116874 -0.121011 -0.120386 0.013271 4
5 1.697898 0.090093 0.015018 0.001061 10 (((RaisedCosineLinearEval(n_basis_funcs=15, wi... {'basis__position__n_basis_funcs': 10, 'basis_... -0.117323 -0.104021 -0.147815 -0.116785 -0.121122 -0.121413 0.014407 6
6 2.078325 0.411285 0.114192 0.117116 20 (((MSplineEval(n_basis_funcs=15, order=4)))) {'basis__position__n_basis_funcs': 20, 'basis_... -0.113746 -0.104792 -0.142771 -0.118216 -0.115467 -0.118998 0.012709 2
7 1.754291 0.077420 0.024533 0.006573 20 (((BSplineEval(n_basis_funcs=15, order=4)))) {'basis__position__n_basis_funcs': 20, 'basis_... -0.115834 -0.104510 -0.146144 -0.118070 -0.120954 -0.121102 0.013707 5
8 1.734065 0.065167 0.016132 0.000732 20 (((RaisedCosineLinearEval(n_basis_funcs=15, wi... {'basis__position__n_basis_funcs': 20, 'basis_... -0.116829 -0.104060 -0.148182 -0.117726 -0.120424 -0.121444 0.014514 7

scikit-learn does not cache every model that it runs (that could get prohibitively large!), but it does store the best estimator, as the appropriately-named best_estimator_.

  • Can easily grab the best estimator, the pipeline that did the best:

best_estim = cv.best_estimator_
best_estim
Pipeline(steps=[('basis',
                 Transformer('(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
                ('glm',
                 PopulationGLM(feature_mask=Array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32), inverse_link_function=<PjitFunction of <function exp at 0x7efda86f6fc0>>, observation_model=PoissonObservations(), regularizer=UnRegularized(), solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

We then visualize the predictions of best_estim the same as before.

  • Visualize model predictions!

visualize_model_predictions(best_estim, transformer_input)
../../_images/2a01e70c766ac7fe9fc2500758f06849fa9661315b2fb51a77082381c59faaad.png

Feature selection#

Now, finally, we understand almost enough about how scikit-learn works to figure out whether both position and speed are necessary inputs, i.e., to do feature selection.

What we would like to do here is comparing alternative models: position + speed, position only or speed only. However, scikit-learn’s cross-validation assumes that the input to the pipeline does not change, only the hyperparameters do. So, how do we go about model selection since we require different input for different model we want to compare?

Here is a neat NeMoS trick to circumvent that. scikit-learn’s GridSearchCV assumes the INPUT stays the same across all models, but for feature selection, we want to compare models with different features (position + speed, position only, speed only). The solution: create a “null” basis that produces zero features, so all models take the same 2D input (position, speed) but some features become empty. First we need to define this “null” basis taking advantage of CustomBasis, which defines a basis from a list of functions.

# this function creates an empty array (n_sample, 0)
def func(x):
    return np.zeros((x.shape[0], 0))

# Create a null basis using the custom basis class
null_basis = nmo.basis.CustomBasis([func]).to_transformer()

# this creates an empty feature
null_basis.compute_features(position).shape
(19237, 0)

Why is this useful? Because we can use this null_basis and basis composition to do model selection.

# first we note that the position + speed basis is in the basis attribute
print(pipe["basis"].basis)

position_bas = nmo.basis.MSplineEval(n_basis_funcs=10).to_transformer()
speed_bas = nmo.basis.MSplineEval(n_basis_funcs=15).to_transformer()

# define 2D basis per each model 
basis_all = position_bas + speed_bas
basis_position = position_bas + null_basis
basis_speed = null_basis + speed_bas

# assign label (not necessary but nice)
basis_all.label = "position + speed"
basis_position.label = "position"
basis_speed.label = "speed"


# then we create a parameter grid defining a grid of 2D basis for each model of interest
param_grid = {
    "basis__basis": 
    [
        basis_all,  
        basis_position, 
        basis_speed 
    ],
}

# finally we define and fit our CV
cv = model_selection.GridSearchCV(pipe, param_grid, cv=cv_folds)
cv.fit(transformer_input, count)
'(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
)
GridSearchCV(cv=5,
             estimator=Pipeline(steps=[('basis',
                                        Transformer('(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
                                       ('glm',
                                        PopulationGLM(feature_mask=Array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]...
                                          Transformer('position': AdditiveBasis(
    basis1=MSplineEval(n_basis_funcs=10, order=4),
    basis2=CustomBasis(
        funcs=[func],
        ndim_input=1,
        output_shape=(0,),
        pynapple_support=True,
        is_complex=False
    ),
)),
                                          Transformer('speed': AdditiveBasis(
    basis1=CustomBasis(
        funcs=[func],
        ndim_input=1,
        output_shape=(0,),
        pynapple_support=True,
        is_complex=False
    ),
    basis2=MSplineEval(n_basis_funcs=15, order=4),
))]})
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
cv_df = pd.DataFrame(cv.cv_results_)

# let's just plot a minimal subset of cols
cv_df[["param_basis__basis", "mean_test_score", "rank_test_score"]]
param_basis__basis mean_test_score rank_test_score
0 Transformer('position + speed': AdditiveBasis(... -0.118486 1
1 Transformer('position': AdditiveBasis(\n ba... -0.119011 2
2 Transformer('speed': AdditiveBasis(\n basis... -0.139847 3

Now use all the tools that you learned to find a better model for this dataset!

Conclusion#

Various combinations of features can lead to different results. Feel free to explore more. To go beyond this notebook, you can check the following references :

References#