
%load_ext autoreload
%autoreload 2

%matplotlib inline
import warnings

warnings.filterwarnings(
    "ignore",
    message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
    category=UserWarning,
)

warnings.filterwarnings(
    "ignore",
    message="Ignoring cached namespace 'core'",
    category=UserWarning,
)

warnings.filterwarnings(
    "ignore",
    message=(
        "invalid value encountered in div "
    ),
    category=RuntimeWarning,
)

Download

This notebook can be downloaded as 03_nemos_advanced.ipynb. See the button at the top right to download as markdown or pdf.

Jupyter Lab tip

Newer versions of Jupyter Lab have addressed an issue with skipping around the notebook while scrolling. To make sure this fix is enabled, in the Jupyter Lab GUI, navigate to Settings > Settings Editor > Notebook and scroll down to the Windowing mode setting and make sure it is set to contentVisibility.

Also, a reminder to the presenter: go to View > Appearance, select Simple Interface and turn off everything else to hide as many bars as possible. And maybe activate Presentation Mode.

And turn on View > Render side-by-side (shortcut Shift+R).

NeMoS Advanced: Cross-Validation and Model Selection

Learning Objectives

In this tutorial we will keep working on the hippocampal place field recordings with the goal of learning how to combine NeMoS and scikit-learn to perform cross-validation and model selection. In particular, we will:

  • Learn how to use NeMoS objects with scikit-learn for cross-validation

  • Learn how to use NeMoS objects with scikit-learn pipelines

  • Learn how to use cross-validation to perform model and feature selection. More specifically, we will compare models including both position and speed as predictors with models including only speed or only position.

Pre-Processing

Let’s first load and wrangle the data with pynapple and NeMoS. You can run the following cells to prepare the variables that we are going to use in the notebook and to recapitulate the content of this dataset with a few visualizations.

import workshop_utils
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pynapple as nap

import nemos as nmo

# some helper plotting functions
from nemos import _documentation_utils as doc_plots

# configure the plot style
plt.style.use(nmo.styles.plot_style)

from sklearn import model_selection
from sklearn import pipeline

# silence the jax-to-numpy conversion warning
nap.nap_config.suppress_conversion_warnings = True
WARNING:2026-02-05 17:43:51,900:jax._src.xla_bridge:876: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
  • Load the data using pynapple.
path = workshop_utils.fetch_data("Achilles_10252013_EEG.nwb")
data = nap.load_file(path)
data
Achilles_10252013_EEG
┍━━━━━━━━━━━━━┯━━━━━━━━━━━━━┑
│ Keys        │ Type        │
┝━━━━━━━━━━━━━┿━━━━━━━━━━━━━┥
│ units       │ TsGroup     │
│ rem         │ IntervalSet │
│ nrem        │ IntervalSet │
│ forward_ep  │ IntervalSet │
│ eeg         │ TsdFrame    │
│ theta_phase │ Tsd         │
│ position    │ Tsd         │
┕━━━━━━━━━━━━━┷━━━━━━━━━━━━━┙
  • Extract the spike times and the animal's position.
spikes = data["units"]
position = data["position"]

For today, we’re only going to focus on the times when the animal was traversing the linear track. These times are stored in data["forward_ep"], a pynapple IntervalSet, so we can use it to restrict our other variables:

  • Restrict data to when animal was traversing the linear track.

position = position.restrict(data["forward_ep"])
spikes = spikes.restrict(data["forward_ep"])

The recording contains both inhibitory and excitatory neurons. Here we will focus on the excitatory cells with a firing rate above 0.3 Hz.

  • Restrict neurons to only excitatory neurons, discarding neurons with a low firing rate.

spikes = spikes.getby_category("cell_type")["pE"]
spikes = spikes.getby_threshold("rate", 0.3)

Place fields

By plotting the neuronal firing rate as a function of position, we can see that these neurons are all tuned for position: they fire in a specific location on the track.

  • Visualize the place fields: neuronal firing rate as a function of position.

place_fields = nap.compute_tuning_curves(spikes, position, bins=50, epochs=position.time_support, feature_names=["distance"])
workshop_utils.plot_place_fields(place_fields)
[Figure: place fields, i.e. firing rate as a function of position, for each neuron]

To decrease computation time, we’re going to spend the rest of the notebook focusing on the neurons highlighted above. We’re also going to bin spikes at 100 Hz and up-sample the position to match that temporal resolution.

  • To save computation time, we’re only going to investigate the three neurons highlighted above.

  • Bin spikes to counts at 100 Hz.

  • Interpolate position to match spike resolution.

neurons = [82, 92, 220]
place_fields = place_fields.sel(unit=neurons)
spikes = spikes[neurons]
bin_size = 0.01  # seconds, i.e. 100 Hz bins
count = spikes.count(bin_size, ep=position.time_support)
position = position.interpolate(count, ep=count.time_support)
print(count.shape)
print(position.shape)
(19237, 3)
(19237,)
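The binning and interpolation steps above can be sketched in plain NumPy. This is an illustration with synthetic spike times and a hypothetical 30 Hz position signal, assuming linear interpolation; it is not pynapple's implementation, which also keeps track of the epochs for you. Variable names end in _demo so the notebook's own variables are left untouched.

```python
import numpy as np

rng = np.random.default_rng(0)
demo_bin_size = 0.01                  # seconds, i.e. 100 Hz
t_start, t_end = 0.0, 2.0

# synthetic spike times for one neuron
spikes_demo = np.sort(rng.uniform(t_start, t_end, size=40))

# bin edges every 10 ms; np.histogram returns one count per bin
edges_demo = np.arange(t_start, t_end + demo_bin_size / 2, demo_bin_size)
counts_demo, _ = np.histogram(spikes_demo, bins=edges_demo)
centers_demo = edges_demo[:-1] + demo_bin_size / 2

# position sampled more slowly (hypothetical 30 Hz tracker), then linearly
# interpolated at the bin centers so that every count has a matching position
pos_t_demo = np.arange(t_start, t_end, 1 / 30)
pos_demo = np.cumsum(rng.normal(size=pos_t_demo.size))
pos_at_bins_demo = np.interp(centers_demo, pos_t_demo, pos_demo)

print(counts_demo.shape, pos_at_bins_demo.shape)   # (200,) (200,)
```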

Extract Speed per Epoch

In the next block, we compute the speed of the animal for each epoch (i.e. each traversal of the linear track) by taking the temporal derivative of the position. You can use the pynapple method derivative, which computes an approximate derivative.

  • Compute the animal’s speed.

  • Visualize tuning curves to speed and position.

speed = position.derivative()
print(speed.shape)

# compute the speed tuning curves and plot them
tc_speed = nap.compute_tuning_curves(spikes, speed, bins=20, epochs=speed.time_support, feature_names=["speed"])
fig = workshop_utils.plot_position_speed(position, speed, place_fields, tc_speed, neurons);

# utility function to visualize predictions
def visualize_model_predictions(glm, X):
    # predict the model's firing rate
    predicted_rate = glm.predict(X) / bin_size

    # compute the position and speed tuning curves using the predicted firing rate.
    glm_pos = nap.compute_tuning_curves(predicted_rate, position, bins=50, epochs=position.time_support, feature_names=["position"])
    glm_speed = nap.compute_tuning_curves(predicted_rate, speed, bins=30, epochs=position.time_support, feature_names=["speed"])

    workshop_utils.plot_position_speed_tuning(place_fields, tc_speed, glm_pos, glm_speed);
(19237,)
[Figure: position and speed, with the corresponding tuning curves of the selected neurons]
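Computing the derivative per epoch matters because the track crossings are separated by gaps in time. The sketch below uses two hypothetical crossings and NumPy's np.gradient within each epoch; it illustrates the idea and is not necessarily the numerical scheme pynapple's derivative uses.

```python
import numpy as np

# two hypothetical track crossings, 1 s each, separated by a 4 s gap
t_run = np.arange(0.0, 1.0, 0.01)
t_demo = np.concatenate([t_run, t_run + 5.0])
x_demo = np.concatenate([100.0 * t_run, 50.0 * t_run])   # runs at 100 and 50 cm/s
epochs_demo = [(0.0, 1.0), (5.0, 6.0)]

# differentiate within each epoch so no difference spans the gap between runs
speed_demo = np.empty_like(x_demo)
for start, end in epochs_demo:
    in_epoch = (t_demo >= start) & (t_demo < end)
    speed_demo[in_epoch] = np.gradient(x_demo[in_epoch], t_demo[in_epoch])

# by contrast, a naive difference over all samples includes one step across the gap
naive_demo = np.diff(x_demo) / np.diff(t_demo)
print(naive_demo[99])   # a spurious negative "speed" between the two runs
```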

Define 1D NeMoS Bases

  • Define the position and speed bases, and visualize them.

position_basis = nmo.basis.MSplineEval(n_basis_funcs=10, label="position")
speed_basis = nmo.basis.MSplineEval(n_basis_funcs=15, label="speed")
workshop_utils.plot_pos_speed_bases(position_basis, speed_basis);
[Figure: position and speed basis functions]

Basis Composition

The first new concept we will introduce is basis composition. NeMoS bases can be composed using the “+” operator (and “*”; see the NeMoS docs for more info) to define more complex predictors.

Adding two 1D bases results in a 2D additive basis. The compute_features method of the additive basis requires 2 inputs, and its output is the concatenation of the design matrices of the basis components.

  • Adding the position and speed bases together defines a 2D basis.

  • Call compute_features to define a design matrix that concatenates both features.

basis = position_basis + speed_basis

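# design matrix of the 2D additive basis
X = basis.compute_features(position, speed)
# the same matrix, built by concatenating the two 1D design matrices by hand
X_numpy = np.concatenate(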
    [
        position_basis.compute_features(position),
        speed_basis.compute_features(speed),
    ],
    axis=1
)

print("Are the design matrices equivalent?", np.all(X.d == X_numpy.d))
Are the design matrices equivalent? True

Scikit-learn

How to know when to regularize?

In the head direction project, we fit the all-to-all connectivity of the head-tuning dataset using the Ridge regularizer, and we learned that regularization can combat overfitting. What we didn’t show is how to choose a proper regularizer. Generally, too much regularization leads to underfitting, i.e. the model is too simple and doesn’t capture the neural variability well. Too little regularization may lead to overfitting, especially when we have a large number of parameters, i.e. our model will capture both signal and noise. This is what we saw in the head direction notebook when we used the raw spike history as a predictor.
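For reference, a generic way to write the Ridge-penalized Poisson GLM objective is shown below, where $y_t$ are the spike counts, $\mathbf{x}_t$ is a row of the design matrix, and $\alpha$ plays the role of regularizer_strength; the constant $\log(y_t!)$ term is dropped because it does not depend on the parameters. This is a sketch of the standard formulation: the exact normalization NeMoS applies to the likelihood and to the penalty may differ, and leaving the intercept $b$ unpenalized is an assumption.

$$
\hat{\mathbf{w}}, \hat{b} = \underset{\mathbf{w},\, b}{\arg\min} \; -\frac{1}{T}\sum_{t=1}^{T} \left[ y_t \log \lambda_t - \lambda_t \right] + \frac{\alpha}{2} \lVert \mathbf{w} \rVert_2^2,
\qquad \lambda_t = \exp\!\left(\mathbf{x}_t^\top \mathbf{w} + b\right)
$$

As $\alpha \to 0$ this reduces to unregularized maximum likelihood; increasing $\alpha$ shrinks $\mathbf{w}$ toward zero, trading variance for bias.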

What we are looking for is a regularization strength that balances the bias toward simpler models against the variance needed to explain the data. However, how do we know how much we should regularize? One thing we can do is use cross-validation to see whether model performance on unseen data improves with regularization (behind the scenes, this is what we did!). We’ll walk through how to do that now.

Instead of implementing our own cross-validation machinery, we, the developers of NeMoS, decided to write the package to be compliant with scikit-learn, the canonical machine learning Python library. Our models are all what scikit-learn calls “estimators”, which means they have .fit, .score, and .predict methods. Thus, we can use them with scikit-learn’s objects out of the box.
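To see what “estimator” means in practice, here is a minimal sketch of a scikit-learn-compatible estimator. ConstantRateModel is a hypothetical toy model (it just predicts the mean training count), not part of NeMoS; the point is that any object following the fit/predict/score convention, with its constructor arguments stored as attributes, can be handed to scikit-learn utilities such as cross_val_score.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_val_score


class ConstantRateModel(BaseEstimator):
    """Hypothetical toy estimator (not part of NeMoS): predicts the mean training count."""

    def __init__(self, min_rate=1e-6):
        # constructor arguments are stored unchanged, so get_params/clone work
        self.min_rate = min_rate

    def fit(self, X, y):
        # fitted attributes end with an underscore, per scikit-learn convention
        self.rate_ = max(float(np.mean(y)), self.min_rate)
        return self

    def predict(self, X):
        return np.full(len(X), self.rate_)

    def score(self, X, y):
        # mean Poisson log-likelihood, dropping the constant log(y!) term;
        # as with all scikit-learn scores, higher is better
        rate = self.predict(X)
        return float(np.mean(y * np.log(rate) - rate))


rng = np.random.default_rng(0)
X_demo = rng.normal(size=(500, 3))           # dummy predictors (ignored by this model)
y_demo = rng.poisson(lam=2.0, size=500)      # synthetic spike counts

# cross_val_score clones the model, fits it on 4 folds and scores it on the 5th, 5 times
scores_demo = cross_val_score(ConstantRateModel(), X_demo, y_demo, cv=5)
print(scores_demo.round(3))
```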

We’re going to use scikit-learn’s GridSearchCV object, which performs a cross-validated grid search, as Edoardo explained in his presentation.

This object requires an estimator (our glm object here) and param_grid, a dictionary defining which parameter values to try. For now, let’s just compare Ridge regularization at two different strengths:

  • How do we decide when to use regularization?

  • Cross-validation allows you to fairly compare different models on the same dataset.

  • NeMoS is compatible with scikit-learn, the standard machine learning library in Python.

  • Define the parameter grid to search over.

  • Anything not specified in the grid will be kept constant.

# define a Ridge PopulationGLM
glm = nmo.glm.PopulationGLM(
    regularizer="Ridge",
    solver_kwargs={"tol": 1e-12},
    solver_name="LBFGS",
)
param_grid = {
    "regularizer_strength": [0.0001, 1.],
}
cv_folds = 5
cv = model_selection.GridSearchCV(glm, param_grid, cv=cv_folds)
cv
GridSearchCV(cv=5,
             estimator=PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=Ridge(), regularizer_strength=1.0, solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'),
             param_grid={'regularizer_strength': [0.0001, 1.0]})

This will take a bit to run, because we’re fitting the model many times!

  • We interact with this in a very similar way to the glm object.

  • In particular, call fit with the same arguments:

cv.fit(X, count)
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
GridSearchCV(cv=5,
             estimator=PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=Ridge(), regularizer_strength=1.0, solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'),
             param_grid={'regularizer_strength': [0.0001, 1.0]})
  • Let’s investigate results:

Cross-validation results are stored in a dictionary attribute called cv_results_, which contains a lot of information. Let’s convert it to a pandas DataFrame for readability:

pd.DataFrame(cv.cv_results_)
   mean_fit_time  std_fit_time  mean_score_time  std_score_time  param_regularizer_strength  params                            split0_test_score  split1_test_score  split2_test_score  split3_test_score  split4_test_score  mean_test_score  std_test_score  rank_test_score
0       1.759626      0.687586         0.363913        0.440074                      0.0001  {'regularizer_strength': 0.0001}          -0.127700          -0.118069          -0.161886          -0.131911          -0.125092        -0.132931        0.015160                1
1       2.978946      2.513113         0.006137        0.001482                      1.0000  {'regularizer_strength': 1.0}             -0.137128          -0.128441          -0.173226          -0.138692          -0.131281        -0.141753        0.016175                2

The most informative column for us is 'mean_test_score', which shows the average of glm.score over the test folds. Following the scikit-learn convention, a higher score is better, so the model performs better with the lower regularization strength (rank 1).
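The same workflow can be sketched with scikit-learn alone, using its PoissonRegressor (whose alpha is an L2 penalty strength) as a stand-in for the Ridge-regularized GLM on synthetic data. The data, the log-spaced grid and the stand-in model are assumptions for illustration. Note that PoissonRegressor.score returns the fraction of Poisson deviance explained rather than a log-likelihood, so its numbers are not comparable to the NeMoS table, but the ranking logic is the same.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import PoissonRegressor
from sklearn.model_selection import GridSearchCV

rng = np.random.default_rng(1)
X_demo = rng.normal(size=(400, 20))
w_demo = np.zeros(20)
w_demo[:3] = [0.5, -0.3, 0.2]                 # only 3 informative predictors
y_demo = rng.poisson(np.exp(0.2 + X_demo @ w_demo))

# stand-in for the Ridge-regularized GLM: alpha is an L2 penalty strength
grid_demo = {"alpha": np.logspace(-4, 1, 6)}  # 1e-4, 1e-3, ..., 10
search_demo = GridSearchCV(PoissonRegressor(max_iter=1000), grid_demo, cv=5)
search_demo.fit(X_demo, y_demo)

# one row per grid point; rank 1 has the highest mean_test_score
results_demo = pd.DataFrame(search_demo.cv_results_)
print(results_demo[["param_alpha", "mean_test_score", "rank_test_score"]])
print("best:", search_demo.best_params_)
```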

Find the best regularization strength!

As an exercise, spend 10 minutes trying to find the best regularization strength!

  • You should use the glm model we defined in this section.

  • You will need to redefine the param_grid dictionary, selecting different values for "regularizer_strength":

param_grid = {
    "regularizer_strength": ...,
}
  • After defining param_grid, reinitialize cv (you can do so with the same arguments).

  • Then call cv.fit and re-run pd.DataFrame(cv.cv_results_) to summarize the results.

Who can find the best regularization strength?

If you finish early, try out different regularizers and try to find the best regularization strength for each of them.

Select basis

We can do something similar to select the basis. In the above example, I just told you which basis functions to use and how many of each. But, in general, you want to select those in a principled manner. Cross-validation to the rescue!

Unlike the glm objects, our basis objects are not scikit-learn compatible right out of the box. However, they can be made compatible by using the .to_transformer() method (or, equivalently, by using the TransformerBasis class).

  • You can (and should) do something similar to determine how many basis functions you need for each input.

  • NeMoS basis objects are not scikit-learn-compatible right out of the box.

  • But we have provided a simple method to make them so:

position_basis = nmo.basis.MSplineEval(n_basis_funcs=10, label="position").to_transformer()
# or equivalently:
position_basis = nmo.basis.TransformerBasis(nmo.basis.MSplineEval(n_basis_funcs=10, label="position"))
position_basis
'position': MSplineEval(n_basis_funcs=10, order=4)

This gives the basis object the transform method, which is equivalent to compute_features. However, transformers have some limits:

position_basis.transform(position)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[17], line 1
----> 1 position_basis.transform(position)

File ~/.venv/lib/python3.12/site-packages/nemos/basis/_transformer_basis.py:247, in TransformerBasis.transform(self, X, y)
    210 """
    211 Transform the data using the fitted basis functions.
    212 
   (...)    244 >>> feature_transformed = transformer.transform(X)
    245 """
    246 self._check_initialized(self.basis)
--> 247 self._check_input(X, y)
    248 # transpose does not work with pynapple
    249 # can't use func(*X.T) to unwrap
    250 return self.basis.compute_features(*self._unpack_inputs(X))

File ~/.venv/lib/python3.12/site-packages/nemos/basis/_transformer_basis.py:554, in TransformerBasis._check_input(self, X, y)
    551     raise ValueError("The input must be a 2-dimensional array.")
    553 elif ndim != 2:
--> 554     raise ValueError(
    555         f"X must be 2-dimensional, shape (n_samples, n_features). The provided X has shape {X.shape} instead."
    556     )
    558 if X.shape[1] != sum(self._input_shape_product):
    559     raise ValueError(
    560         f"Input mismatch: expected {sum(self._input_shape_product)} inputs, but got {X.shape[1]} "
    561         f"columns in X.\nTo modify the required number of inputs, call `set_input_shape` before using "
    562         f"`fit` or `fit_transform`."
    563     )

ValueError: X must be 2-dimensional, shape (n_samples, n_features). The provided X has shape (19237,) instead.
  • Transformers only accept 2D inputs, whereas NeMoS basis objects can accept inputs of any dimensionality.

  • In order to use a basis as a transformer, you’ll need to concatenate all your inputs into a single 2D array.

position_basis.transform(position[:, np.newaxis])
Time (s)                0           1            2            3    4  ...
---------------  --------  ----------  -----------  -----------  ---  -----
18193.603802655  0.162854  0.0063002   5.33261e-05  1.12053e-07    0  ...
18193.613802655  0.159556  0.00790116  8.52389e-05  2.27854e-07    0  ...
18193.623802655  0.156303  0.00946868  0.000124421  4.04332e-07    0  ...
18193.633802655  0.151106  0.0119478   0.000203448  8.54051e-07    0  ...
18193.643802655  0.145902  0.014398    0.000303689  1.57397e-06    0  ...
18193.653802655  0.141967  0.0162281   0.000394145  2.34623e-06    0  ...
18193.663802655  0.139957  0.0171555   0.000445419  2.8306e-06     0  ...
...                                                                   ...
20123.332682821  0         0           0            0              0  ...
20123.342682821  0         0           0            0              0  ...
20123.352682821  0         0           0            0              0  ...
20123.362682821  0         0           0            0              0  ...
20123.372682821  0         0           0            0              0  ...
20123.382682821  0         0           0            0              0  ...
20123.392682821  0         0           0            0              0  ...
dtype: float32, shape: (19237, 10)
Other Caveats

If the basis has more than one component (for example, if it is the addition of two 1D bases), the transformer will expect an input of shape (n_samples, 1) per component. If that’s not the case, you’ll need to provide a different input shape by calling set_input_shape.

Case 1) One input per component:

# generate a composite basis
basis_2d = nmo.basis.MSplineEval(5) + nmo.basis.MSplineEval(5)
basis_2d = basis_2d.to_transformer()

# this will work: 1 input per component
x, y = np.random.randn(10, 1), np.random.randn(10, 1)
X = np.concatenate([x, y], axis=1)
result = basis_2d.transform(X)

Case 2) Multiple inputs per component.

  • If one or more bases process multiple inputs (multiple columns of the 2D array), calling the transform method directly will lead to an error.

  • This is because the basis doesn’t know which component should process which column.

# Assume 2 inputs for the first component and 3 for the second.
x, y = np.random.randn(10, 2), np.random.randn(10, 3)
X = np.concatenate([x, y], axis=1)

res = basis_2d.transform(X)  # This will raise an exception!

To prevent that, use set_input_shape to define how many inputs each component should process.

# Set the expected input shape instead; any of these options works:

# array
res1 = basis_2d.set_input_shape(x, y).transform(X)
# int
res2 = basis_2d.set_input_shape(2, 3).transform(X)
# tuple
res3 = basis_2d.set_input_shape((2,), (3,)).transform(X)
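To see why the shape information is needed, consider how a concatenated 2D input could be split back into per-component blocks. The helper split_columns below is hypothetical and only illustrates the bookkeeping, assuming columns are assigned to components in order; it is not NeMoS's internal code. Without the number of inputs per component, the column boundaries are ambiguous.

```python
import numpy as np


def split_columns(X, n_inputs_per_component):
    """Hypothetical helper: split a 2D input into one column block per component."""
    n_inputs = list(n_inputs_per_component)
    if X.shape[1] != sum(n_inputs):
        raise ValueError(f"expected {sum(n_inputs)} columns, got {X.shape[1]}")
    # column indices where one component's block ends and the next begins
    boundaries = np.cumsum(n_inputs)[:-1]
    return np.split(X, boundaries, axis=1)


rng = np.random.default_rng(0)
x_demo, y_demo = rng.normal(size=(10, 2)), rng.normal(size=(10, 3))
X_demo = np.concatenate([x_demo, y_demo], axis=1)   # shape (10, 5)

# the default assumption, one column per component, does not fit 5 columns
try:
    split_columns(X_demo, [1, 1])
except ValueError as err:
    print(err)                                      # expected 2 columns, got 5

# with the input shapes declared, the split is unambiguous
blocks_demo = split_columns(X_demo, [2, 3])
print([block.shape for block in blocks_demo])      # [(10, 2), (10, 3)]
```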
  • Let’s now create the composite basis for speed and position.

position_basis = nmo.basis.MSplineEval(n_basis_funcs=10, label="position")
speed_basis = nmo.basis.MSplineEval(n_basis_funcs=15, label="speed")
basis = position_basis + speed_basis
basis = basis.to_transformer()
basis
'(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
)

Let’s create a single TsdFrame to hold all our inputs:

  • Stack position and speed in a single TsdFrame to hold all our inputs:

transformer_input = nap.TsdFrame(
    t=position.t,
    d=np.stack([position, speed]).T,
    time_support=position.time_support,
    columns=["position", "speed"],
)
  • Pass this input to our transformed additive basis.

Our new additive transformer basis can then take these behavioral inputs and turn them into the model’s design matrix.

basis.transform(transformer_input)
Time (s)                0           1            2            3    4  ...
---------------  --------  ----------  -----------  -----------  ---  -----
18193.603802655  0.162854  0.0063002   5.33261e-05  1.12053e-07    0  ...
18193.613802655  0.159556  0.00790116  8.52389e-05  2.27854e-07    0  ...
18193.623802655  0.156303  0.00946868  0.000124421  4.04332e-07    0  ...
18193.633802655  0.151106  0.0119478   0.000203448  8.54051e-07    0  ...
18193.643802655  0.145902  0.014398    0.000303689  1.57397e-06    0  ...
18193.653802655  0.141967  0.0162281   0.000394145  2.34623e-06    0  ...
18193.663802655  0.139957  0.0171555   0.000445419  2.8306e-06     0  ...
...                                                                   ...
20123.332682821  0         0           0            0              0  ...
20123.342682821  0         0           0            0              0  ...
20123.352682821  0         0           0            0              0  ...
20123.362682821  0         0           0            0              0  ...
20123.372682821  0         0           0            0              0  ...
20123.382682821  0         0           0            0              0  ...
20123.392682821  0         0           0            0              0  ...
dtype: float32, shape: (19237, 25)

Pipelines

We need one more step: scikit-learn cross-validation operates on an estimator, like our GLMs. If we want to cross-validate over the basis or its features, we need to combine our transformer basis with the estimator into a single estimator object. Luckily, scikit-learn provides tools for this: pipelines.

Pipelines are objects that accept a series of (0 or more) transformers, culminating in a final estimator. This is defined as a list of tuples, with each tuple containing a human-readable label and the object itself:

  • If we want to cross-validate over the basis, we need one more step: combining the basis and the GLM into a single scikit-learn estimator.

  • Pipelines to the rescue!

# unregularized GLM; you can also plug in the regularizer and strength selected earlier
glm = nmo.glm.PopulationGLM(solver_name="LBFGS", solver_kwargs={"tol": 10**-12})
pipe = pipeline.Pipeline([
    ("basis", basis),
    ("glm", glm)
])
pipe
Pipeline(steps=[('basis',
                 Transformer('(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
                ('glm',
                 PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=UnRegularized(), solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'))])

This pipeline object allows us to, e.g., call fit directly on the initial input:

  • Pipeline runs basis.transform, then passes that output to glm, so we can do everything in a single line:

pipe.fit(transformer_input, count)
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
Pipeline(steps=[('basis',
                 Transformer('(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
                ('glm',
                 PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=UnRegularized(), solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'))])

We then visualize the predictions the same as before, using pipe instead of glm.

  • Visualize model predictions!

visualize_model_predictions(pipe, transformer_input)
[Figure: position and speed tuning curves computed from the data and from the model predictions]

Cross-validating on the basis

Now that we have our pipeline estimator, we can cross-validate on any of its parameters!

pipe.steps
[('basis',
  Transformer('(position + speed)': AdditiveBasis(
      basis1='position': MSplineEval(n_basis_funcs=10, order=4),
      basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
  ))),
 ('glm',
  PopulationGLM(
      observation_model=PoissonObservations(),
      inverse_link_function=exp,
      regularizer=UnRegularized(),
      solver_name='LBFGS',
      solver_kwargs={'tol': 1e-12}
  ))]

Let’s cross-validate on the number of basis functions for the position basis, and the identity of the basis for the speed. That is:

Let’s cross-validate on:

  • The number of the basis functions of the position basis

  • The functional form of the basis for speed

  • Let’s retrieve those attributes from the pipeline.

# the label of the pipeline step retrieves the basis
print(pipe["basis"])

# the position basis can be retrieved by its label
print("\n", pipe["basis"]["position"])

# n_basis_funcs is an attribute of the position basis
print("\n", pipe["basis"]["position"].n_basis_funcs)

# with the same syntax we can retrieve the speed basis
print("\n", pipe["basis"]["speed"])
Transformer('(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))

 Transformer('position': MSplineEval(n_basis_funcs=10, order=4))

 10

 Transformer('speed': MSplineEval(n_basis_funcs=15, order=4))

For scikit-learn parameter grids, we use __ to stand in for .:

  • Construct param_grid, using __ to stand in for .

  • In scikit-learn pipelines, we access nested parameters using double underscores:

    • pipe["basis"]["position"].n_basis_funcs - normal Python syntax

    • "basis__position__n_basis_funcs" - scikit-learn parameter grid syntax

param_grid = {
    "basis__position__n_basis_funcs": [5, 10, 20],
    "basis__speed": [nmo.basis.MSplineEval(15),
                      nmo.basis.BSplineEval(15),
                      nmo.basis.RaisedCosineLinearEval(15)],
}
  • Cross-validate as before:

cv = model_selection.GridSearchCV(pipe, param_grid, cv=cv_folds)
cv.fit(transformer_input, count)
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
GridSearchCV(cv=5,
             estimator=Pipeline(steps=[('basis',
                                        Transformer('(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
                                       ('glm',
                                        PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=UnRegularized(), solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'))]),
             param_grid={'basis__position__n_basis_funcs': [5, 10, 20],
                         'basis__speed': [MSplineEval(n_basis_funcs=15),
                                          BSplineEval(n_basis_funcs=15),
                                          RaisedCosineLinearEval(n_basis_funcs=15)]})
  • Investigate results:

pd.DataFrame(cv.cv_results_)
mean_fit_time std_fit_time mean_score_time std_score_time param_basis__position__n_basis_funcs param_basis__speed params split0_test_score split1_test_score split2_test_score split3_test_score split4_test_score mean_test_score std_test_score rank_test_score
0 1.922558 0.666500 0.527833 0.528948 5 (((MSplineEval(n_basis_funcs=15, order=4)))) {'basis__position__n_basis_funcs': 5, 'basis__... -0.136246 -0.116545 -0.147549 -0.122671 -0.128592 -0.130321 0.010800 3
1 1.495269 0.247778 0.037604 0.004972 5 (((BSplineEval(n_basis_funcs=15, order=4)))) {'basis__position__n_basis_funcs': 5, 'basis__... -0.150079 -0.155905 -0.172195 -0.126545 -0.198609 -0.160667 0.023965 8
2 1.495143 0.226423 0.144760 0.129065 5 (((RaisedCosineLinearEval(n_basis_funcs=15, wi... {'basis__position__n_basis_funcs': 5, 'basis__... -0.163750 -0.128224 -0.184139 -0.125410 -0.216403 -0.163585 0.034409 9
3 1.619937 0.185751 0.145199 0.046221 10 (((MSplineEval(n_basis_funcs=15, order=4)))) {'basis__position__n_basis_funcs': 10, 'basis_... -0.127216 -0.113708 -0.148560 -0.122756 -0.132335 -0.128915 0.011573 2
4 1.626782 0.260083 0.055368 0.008168 10 (((BSplineEval(n_basis_funcs=15, order=4)))) {'basis__position__n_basis_funcs': 10, 'basis_... -0.153041 -0.131869 -0.165941 -0.123137 -0.176732 -0.150144 0.020141 6
5 1.539392 0.141097 0.057420 0.009103 10 (((RaisedCosineLinearEval(n_basis_funcs=15, wi... {'basis__position__n_basis_funcs': 10, 'basis_... -0.148822 -0.126507 -0.168947 -0.122744 -0.168339 -0.147072 0.019742 4
6 1.914416 0.385908 0.283016 0.193542 20 (((MSplineEval(n_basis_funcs=15, order=4)))) {'basis__position__n_basis_funcs': 20, 'basis_... -0.126687 -0.112765 -0.146401 -0.124083 -0.123058 -0.126599 0.010976 1
7 1.756607 0.184487 0.079363 0.019208 20 (((BSplineEval(n_basis_funcs=15, order=4)))) {'basis__position__n_basis_funcs': 20, 'basis_... -0.153473 -0.132539 -0.167847 -0.124594 -0.179855 -0.151662 0.020777 7
8 1.638160 0.157081 0.102488 0.014956 20 (((RaisedCosineLinearEval(n_basis_funcs=15, wi... {'basis__position__n_basis_funcs': 20, 'basis_... -0.146942 -0.128252 -0.172728 -0.124165 -0.167277 -0.147873 0.019709 5
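The mean_test_score, std_test_score and rank_test_score columns summarize the five split*_test_score columns. As a sanity check, the sketch below recomputes them with NumPy from the split scores printed in the table above. Those scores are rounded to six decimals, so agreement is only to roughly 1e-6. It assumes scikit-learn's conventions: the standard deviation is the population one (ddof=0), and rank 1 goes to the highest (least negative) mean score.

```python
import numpy as np

# Split scores copied from the cv.cv_results_ table above (rows 0-8, folds 0-4).
split_scores = np.array([
    [-0.136246, -0.116545, -0.147549, -0.122671, -0.128592],
    [-0.150079, -0.155905, -0.172195, -0.126545, -0.198609],
    [-0.163750, -0.128224, -0.184139, -0.125410, -0.216403],
    [-0.127216, -0.113708, -0.148560, -0.122756, -0.132335],
    [-0.153041, -0.131869, -0.165941, -0.123137, -0.176732],
    [-0.148822, -0.126507, -0.168947, -0.122744, -0.168339],
    [-0.126687, -0.112765, -0.146401, -0.124083, -0.123058],
    [-0.153473, -0.132539, -0.167847, -0.124594, -0.179855],
    [-0.146942, -0.128252, -0.172728, -0.124165, -0.167277],
])

# Mean and population (ddof=0) standard deviation across the 5 folds.
mean_test_score = split_scores.mean(axis=1)
std_test_score = split_scores.std(axis=1)

# Rank 1 = highest mean score (there are no ties in this table).
order = np.argsort(-mean_test_score)
rank_test_score = np.empty_like(order)
rank_test_score[order] = np.arange(1, len(order) + 1)

print(np.round(mean_test_score, 6))
print(rank_test_score)
```

Row 6 (20 position basis functions, MSplineEval for speed) comes out on top, which is the combination stored in best_estimator_.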

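scikit-learn does not cache every model that it fits (that could get prohibitively large!), but it does store the best estimator as the appropriately named best_estimator_ attribute. With GridSearchCV’s default refit=True, this is the best-ranked parameter combination refit on the full dataset.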

  • We can easily grab the best estimator, i.e., the pipeline with the best mean cross-validated score:

best_estim = cv.best_estimator_
best_estim
Pipeline(steps=[('basis',
                 Transformer('(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=20, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
                ('glm',
                 PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=UnRegularized(), solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'))])
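With the default refit=True, best_estimator_ is the candidate whose rank_test_score is 1, refit once more on the whole dataset rather than taken from one of the per-fold fits. The sketch below checks these relationships on a plain scikit-learn PoissonRegressor with synthetic counts. The regressor, the data and the alpha grid are stand-ins (assumptions) so that the example runs without NeMoS or the hippocampal recording.

```python
import numpy as np
from sklearn.linear_model import PoissonRegressor
from sklearn.model_selection import GridSearchCV

# Synthetic Poisson counts standing in for the spike counts (assumption).
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 3))
y = rng.poisson(np.exp(0.3 * X[:, 0] - 0.2 * X[:, 1]))

# Hypothetical grid over the L2 penalty strength.
param_grid = {"alpha": [1e-4, 1e-2, 1.0]}
cv = GridSearchCV(PoissonRegressor(), param_grid, cv=5)  # refit=True by default
cv.fit(X, y)

# best_index_ points at the cv_results_ row with rank_test_score == 1 ...
best_row = cv.best_index_
print(cv.cv_results_["rank_test_score"][best_row])
print(cv.best_params_, cv.best_score_)

# ... and best_estimator_ is a clone refit on all of X, y with best_params_.
best_estim = cv.best_estimator_
print(best_estim.alpha == cv.best_params_["alpha"])
print(best_estim.coef_)
```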

We then visualize the predictions of best_estim in the same way as before.

  • Visualize model predictions!

visualize_model_predictions(best_estim, transformer_input)
[Figure: model predictions of best_estim]

Find the best basis!

As an exercise, spend 10 minutes exploring the possible basis objects and seeing which performs the best.

  • You should use the pipe object we defined in this section.

  • You will need to redefine the param_grid dictionary, setting basis__speed and basis__position (or their attributes, e.g., basis__position__n_basis_funcs) to a range of values. Remember that all combinations are tested, so if you, e.g., select 5 choices for each, you’ll be testing 25 different combinations, each fit once per cross-validation fold!

param_grid = {
    "basis__position": ...,
    "basis__speed": ...,
}
  • After defining param_grid, reinitialize cv (you can do so with the same arguments).

  • Then call cv.fit and re-run pd.DataFrame(cv.cv_results_) to summarize the results.

  • Finally, visualize the best estimator with visualize_model_predictions(cv.best_estimator_, transformer_input)

Who can find the best set of basis objects?
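The combinatorial growth mentioned in the exercise is easy to check before launching a long search. scikit-learn's ParameterGrid enumerates the same combinations GridSearchCV will try. The grid values below are hypothetical, and the fold count assumes cv=5 as in the search above.

```python
from sklearn.model_selection import ParameterGrid

# Hypothetical grid: 5 values for each of two parameters.
param_grid = {
    "basis__position__n_basis_funcs": [5, 10, 15, 20, 25],
    "basis__speed__n_basis_funcs": [5, 10, 15, 20, 25],
}

n_candidates = len(ParameterGrid(param_grid))  # every combination is tried
n_folds = 5
n_fits = n_candidates * n_folds  # each candidate is fit once per fold
print(n_candidates, n_fits)
```

With the default refit=True, GridSearchCV adds one final fit of the best candidate on the full data.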

Feature selection#

Now that we understand how scikit-learn works with NeMoS, we can determine whether both position and speed are necessary inputs by performing feature selection.

Our goal is to compare alternative models: position + speed, position only, or speed only. However, scikit-learn’s cross-validation assumes that the input to the pipeline stays constant—only the hyperparameters change. So how can we compare models that require different features?

Here’s a clever NeMoS trick: we’ll create a “null” basis that produces zero features. This way, all models take the same 2-D input (position, speed), but in the reduced models one of the two feature blocks is an empty array with zero columns. We can define this null basis using CustomBasis, which creates a basis from a list of functions.

Let’s move on to feature selection. Our goal is to compare alternative models: position + speed, position only, or speed only.

Problem: scikit-learn’s cross-validation assumes the pipeline input stays constant, but each model needs different features. How do we solve this?

Solution: Use a “null” basis that produces zero features!

  • We’ll create this null basis using CustomBasis, which defines a basis from custom functions.

# define a function that creates an empty array (n_samples, 0)
def func(x):
    return np.zeros((x.shape[0], 0))

# create a null transformer basis using the custom basis class
null_basis = nmo.basis.CustomBasis([func]).to_transformer()

# verify: this creates an empty feature array
null_basis.compute_features(position).shape
(19237, 0)

Why is this useful? We can combine null_basis with actual bases to create different models that all accept the same input!
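The trick relies on a simple property of column-wise concatenation: a block with zero columns contributes nothing. The NumPy sketch below assumes that an additive basis places its components' features side by side, which is consistent with the 10 + 15 = 25 columns of the position + speed design matrix. The arrays are placeholders, not real basis outputs.

```python
import numpy as np

n_samples = 6
position_features = np.ones((n_samples, 10))  # placeholder for 10 position features
speed_features = np.ones((n_samples, 15))     # placeholder for 15 speed features
null_features = np.zeros((n_samples, 0))      # what the null basis returns

# Concatenating along the feature axis: the empty block adds no columns.
print(np.hstack([position_features, speed_features]).shape)  # (6, 25)
print(np.hstack([position_features, null_features]).shape)   # (6, 10)
print(np.hstack([null_features, speed_features]).shape)      # (6, 15)
```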

Let’s define the bases for our three models:

  • Position + speed: combine position and speed bases

  • Position only: combine position basis with null basis (the speed features are empty)

  • Speed only: combine null basis with speed basis (the position features are empty)

# combine them to define each model
basis_all = (position_basis + speed_basis).to_transformer()
basis_position = (position_basis + null_basis).to_transformer()
basis_speed = (null_basis + speed_basis).to_transformer()

# assign labels (optional but helpful for readability)
basis_all.label = "position + speed"
basis_position.label = "position only"
basis_speed.label = "speed only"

These bases can all transform the same transformer_input (a TsdFrame with columns for position and speed), but they generate design matrices with different numbers of features:

# "position + speed" design: 25 features (10 + 15)
print("position + speed design matrix shape:")
print(basis_all.transform(transformer_input).shape)

# "position" design: 10 features (10 + 0)
print("\nposition design matrix shape:")
print(basis_position.transform(transformer_input).shape)

# "speed" design: 15 features (0 + 15)
print("\nspeed design matrix shape:")
print(basis_speed.transform(transformer_input).shape)
position + speed design matrix shape:
(19237, 25)

position design matrix shape:
(19237, 10)

speed design matrix shape:
(19237, 15)

To cross-validate over different basis compositions, we need to understand how they’re stored in our pipeline. The additive basis is stored as a basis attribute inside the TransformerBasis object:

# the "basis" step in our pipeline contains a TransformerBasis
# which has a "basis" attribute storing the actual additive basis
pipe["basis"].basis
'(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
)

Now we can create a parameter grid for cross-validation. The key is the string "basis__basis":

  • First basis: the name of the pipeline step

  • Second basis: the attribute of the TransformerBasis object

  • This double-underscore notation is how scikit-learn accesses nested parameters
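The double-underscore convention is general scikit-learn behaviour rather than something specific to NeMoS. The sketch below shows it on a plain scikit-learn pipeline; StandardScaler and PoissonRegressor are stand-ins, and "scaler" and "glm" are simply the step names we chose.

```python
from sklearn.linear_model import PoissonRegressor
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# A two-step pipeline; step names are arbitrary strings.
pipe = Pipeline([("scaler", StandardScaler()), ("glm", PoissonRegressor())])

# "<step>__<parameter>" addresses a parameter of a named step.
params = pipe.get_params()
print(params["glm__alpha"])         # 1.0, PoissonRegressor's default
print(params["scaler__with_mean"])  # True

# The same keys work with set_params, which GridSearchCV uses for each candidate.
pipe.set_params(glm__alpha=0.1)
print(pipe.named_steps["glm"].alpha)  # 0.1
```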

# create parameter grid with our three basis compositions
param_grid = {
    "basis__basis": [
        basis_all,      # position + speed
        basis_position, # position only
        basis_speed     # speed only
    ],
}
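Putting whole objects, rather than scalar hyperparameters, into the grid is standard scikit-learn behaviour, so the same feature-selection pattern can be tried without NeMoS. In the sketch below, FunctionTransformer column selectors play the role of the three bases (both columns, first only, second only), a PoissonRegressor with a tiny penalty stands in for the GLM, and the data are synthetic, with only the first column driving the counts. The function names both_columns, first_column and second_column are hypothetical.

```python
import numpy as np
from sklearn.linear_model import PoissonRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer


# Column selectors: every candidate receives the same 2-column input.
def both_columns(X):
    return X


def first_column(X):
    return X[:, :1]


def second_column(X):
    return X[:, 1:]


# Synthetic data in which only the first column drives the counts.
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
y = rng.poisson(np.exp(1.0 * X[:, 0]))

pipe = Pipeline([
    ("features", FunctionTransformer(both_columns)),
    ("glm", PoissonRegressor(alpha=1e-6)),
])

# Grid values are whole transformers, swapped in as the "features" step.
param_grid = {
    "features": [
        FunctionTransformer(both_columns),
        FunctionTransformer(first_column),
        FunctionTransformer(second_column),
    ],
}
cv = GridSearchCV(pipe, param_grid, cv=5)
cv.fit(X, y)

# Report each candidate's rank (1 = best mean held-out score).
for params, rank in zip(cv.cv_results_["params"], cv.cv_results_["rank_test_score"]):
    print(params["features"].func.__name__, rank)
```

With this setup the second-column-only model should rank last; how the two candidates that contain the first column compare depends on the noise in each fold.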
# define and fit GridSearchCV
cv = model_selection.GridSearchCV(pipe, param_grid, cv=cv_folds)
cv.fit(transformer_input, count)
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)`` 
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
  warnings.warn(
GridSearchCV(cv=5,
             estimator=Pipeline(steps=[('basis',
                                        Transformer('(position + speed)': AdditiveBasis(
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
                                       ('glm',
                                        PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=UnRegularized(...
    basis1='position': MSplineEval(n_basis_funcs=10, order=4),
    basis2=CustomBasis(
        funcs=[func],
        ndim_input=1,
        output_shape=(0,),
        pynapple_support=True,
        is_complex=False
    ),
)),
                                          Transformer('speed only': AdditiveBasis(
    basis1=CustomBasis(
        funcs=[func],
        ndim_input=1,
        output_shape=(0,),
        pynapple_support=True,
        is_complex=False
    ),
    basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))]})
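Conceptually, GridSearchCV.fit loops over every candidate in the grid and every fold: it clones the estimator, sets the candidate's parameters, fits on the training folds and scores on the held-out fold. With three candidates and cv=5 that is 15 fits, followed by one refit of the best candidate on all the data (refit=True is the default). The sketch below reproduces the per-candidate mean test scores by hand. It is a stand-in, not the NeMoS pipeline: it uses scikit-learn's PoissonRegressor on synthetic data and a grid over its alpha parameter instead of over basis objects, and PoissonRegressor.score returns the fraction of deviance explained, so its numbers are on a different scale from the GLM scores here.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import PoissonRegressor
from sklearn.model_selection import GridSearchCV, KFold, ParameterGrid

# synthetic data: 2 features, Poisson counts driven by the first feature only
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))
y = rng.poisson(np.exp(0.5 * X[:, 0]))

est = PoissonRegressor()
grid = {"alpha": [0.01, 1.0, 10.0]}
folds = KFold(n_splits=5)  # unshuffled 5-fold splits, as cv=5 gives a regressor

# manual version: one fresh, unfitted copy per (candidate, fold)
manual_means = []
for params in ParameterGrid(grid):
    fold_scores = []
    for train_idx, test_idx in folds.split(X):
        model = clone(est).set_params(**params)
        model.fit(X[train_idx], y[train_idx])
        fold_scores.append(model.score(X[test_idx], y[test_idx]))
    manual_means.append(np.mean(fold_scores))

# the same search through GridSearchCV
search = GridSearchCV(est, grid, cv=folds).fit(X, y)
print(np.round(manual_means, 4))
print(np.round(search.cv_results_["mean_test_score"], 4))
```

The two printed rows should agree; after the loop, GridSearchCV additionally stores the refitted winner in search.best_estimator_.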
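The convergence warnings emitted during the search suggest two remedies. The first, enabling 64-bit floats in JAX, is the single configuration call quoted in the warning; a minimal check, assuming JAX is installed, is below. It is best run at the top of the notebook, before arrays are created or models are fitted, because it changes the default dtype of newly created arrays. The second remedy goes through the solver_kwargs dictionary; the accepted keys depend on the solver, so check its docstring as the warning suggests.

```python
import jax
import jax.numpy as jnp

# switch JAX's default floating-point precision from float32 to float64;
# run this before creating arrays or fitting models
jax.config.update("jax_enable_x64", True)

x = jnp.array([1.0, 2.0, 3.0])
print(x.dtype)  # float64
```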
Let's examine the model comparison results:
cv_df = pd.DataFrame(cv.cv_results_)

# display the key columns: which basis was used, its score, and ranking
cv_df[["param_basis__basis", "mean_test_score", "rank_test_score"]]
| | param_basis__basis | mean_test_score | rank_test_score |
|---|---|---|---|
| 0 | Transformer('position + speed': AdditiveBasis(... | -0.128915 | 2 |
| 1 | Transformer('position only': AdditiveBasis(... | -0.118960 | 1 |
| 2 | Transformer('speed only': AdditiveBasis(... | -0.175172 | 3 |
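Position emerges as the predictor with the greatest explanatory power: the position-only model has the highest mean test score (rank 1), and adding speed slightly lowers the cross-validated score rather than improving it. With these basis and regularization choices, speed does not appear to add predictive information beyond position, and speed alone is the weakest of the three models.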
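As a quick check of that reading, the snippet below recomputes the ranking from the mean_test_score values in the cv_df table. It assumes, as GridSearchCV does, that a higher (less negative) score is better, so rank 1 goes to the highest mean test score.

```python
import numpy as np
from scipy.stats import rankdata

# mean_test_score values copied from the cv_df table
labels = ["position + speed", "position only", "speed only"]
mean_test_score = np.array([-0.128915, -0.118960, -0.175172])

# higher is better, so rank the negated scores (ties share the lowest rank)
rank = rankdata(-mean_test_score, method="min").astype(int)
print(dict(zip(labels, rank.tolist())))
# {'position + speed': 2, 'position only': 1, 'speed only': 3}

best = labels[int(np.argmax(mean_test_score))]
print(best)  # position only

# change in mean test score when speed is added on top of position
delta = float(mean_test_score[0] - mean_test_score[1])
print(round(delta, 6))  # -0.009955
```

Whether a gap of about 0.01 is meaningful relative to fold-to-fold variability can be judged from the std_test_score column of cv.cv_results_.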

Find the model!

In this section, we compared only one regularization strength and one set of basis objects for each feature combination. As an exercise, spend 10 minutes combining what we learned here with the earlier sections: for each feature combination (position, speed, position + speed), try several different basis objects and, optionally, different regularization strengths.

Don’t forget to visualize your model’s predictions!

Who can find the best model?
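For the exercise, keep in mind that GridSearchCV tries every combination in the grid, so the number of fits grows multiplicatively. ParameterGrid, which GridSearchCV uses to enumerate candidates, makes this easy to check before launching a search. The strings below are hypothetical placeholders for basis objects, and the glm__regularizer and glm__regularizer_strength keys assume the pipeline step is named "glm" (as in this notebook) and that the NeMoS GLM exposes regularizer and regularizer_strength parameters; a strength only matters with a regularizer such as Ridge, not with UnRegularized.

```python
from sklearn.model_selection import ParameterGrid

# placeholder strings stand in for basis objects (hypothetical names)
param_grid_demo = {
    "basis__basis": [
        "position_mspline",
        "position_bspline",
        "speed_mspline",
        "position+speed_mspline",
    ],
    "glm__regularizer": ["Ridge"],                    # assumed NeMoS GLM parameter
    "glm__regularizer_strength": [1e-3, 1e-2, 1e-1],  # assumed NeMoS GLM parameter
}

candidates = list(ParameterGrid(param_grid_demo))
n_folds = 5
print(len(candidates))            # 12 = 4 bases x 1 regularizer x 3 strengths
print(len(candidates) * n_folds)  # 60 fits, plus one refit of the best candidate
print(candidates[0])
# {'basis__basis': 'position_mspline', 'glm__regularizer': 'Ridge', 'glm__regularizer_strength': 0.001}
```

If some combinations are not worth trying, param_grid also accepts a list of dicts, and each dict is expanded into its own product.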

Next Steps#

For the next project, you can use all the tools showcased here to find a better encoding model for these hippocampal neurons.

Suggestions:

  • Extend the model by including theta phase as a predictor

  • Use the NeMoS MultiplicativeBasis to capture interactions between theta phase and position
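Conceptually, a multiplicative composition multiplies every function of one basis with every function of the other at each time point (a row-wise outer product), so two bases with n1 and n2 functions yield n1 * n2 features, whereas an additive composition concatenates n1 + n2 columns. The NumPy sketch below illustrates this with hypothetical stand-ins (Gaussian bumps for position, sine and cosine harmonics for theta phase); it is not the NeMoS implementation, and the column ordering NeMoS uses may differ. In NeMoS itself you would presumably compose two basis objects with the * operator instead of +, as described in its basis documentation.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 1000
position = rng.uniform(0.0, 100.0, n_samples)         # hypothetical track position
theta_phase = rng.uniform(0.0, 2 * np.pi, n_samples)  # hypothetical phase in radians

# stand-in position basis: 10 Gaussian bumps tiling the track -> (T, 10)
centers = np.linspace(0.0, 100.0, 10)
X_pos = np.exp(-0.5 * ((position[:, None] - centers) / 10.0) ** 2)

# stand-in periodic phase basis: cos/sin of the first 3 harmonics -> (T, 6)
X_phase = np.column_stack(
    [f(k * theta_phase) for k in (1, 2, 3) for f in (np.cos, np.sin)]
)

# additive composition: concatenate columns -> (T, 10 + 6)
X_add = np.hstack([X_pos, X_phase])

# multiplicative composition: row-wise outer product -> (T, 10 * 6);
# column i * 6 + j holds X_pos[:, i] * X_phase[:, j]
X_mult = np.einsum("ti,tj->tij", X_pos, X_phase).reshape(n_samples, -1)

print(X_add.shape)   # (1000, 16)
print(X_mult.shape)  # (1000, 60)
```

Each interaction column is large only near one position bump, and its weight lets the model scale that location's response by phase. The feature count grows quickly, so regularization becomes more important with multiplicative terms.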

References#

The data in this tutorial comes from [Grosmark, Andres D., and György Buzsáki. "Diversity in neural firing dynamics supports both rigid and learned hippocampal sequences." Science 351.6280 (2016): 1440-1443](https://www.science.org/doi/full/10.1126/science.aad1935).