Download
This notebook can be downloaded as 03_nemos_advanced.ipynb. See the button at the top right to download as markdown or pdf.
Jupyter Lab tip
Newer versions of Jupyter Lab have addressed an issue with skipping around the notebook while scrolling. To make sure this fix is enabled, in the Jupyter Lab GUI, navigate to Settings > Settings Editor > Notebook and scroll down to the Windowing mode setting and make sure it is set to contentVisibility.
NeMoS Advanced: Cross-Validation and Model Selection#
Learning Objectives#
In this tutorial we will keep working on the hippocampal place field recordings with the goal of learning how to combine NeMoS and scikit-learn to perform cross-validation and model selection. In particular we will:
Learn how to use NeMoS objects with scikit-learn for cross-validation
Learn how to use NeMoS objects with scikit-learn pipelines
Learn how to use cross-validation to perform model and feature selection. More specifically, we will compare models including both position and speed as predictors with models including only speed or only position.
Pre-Processing#
Let’s first load and wrangle the data with pynapple and NeMoS. You can run the following cells to prepare the variables we are going to use in the notebook and to recapitulate the content of this dataset with a few visualizations.
import workshop_utils
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pynapple as nap
import nemos as nmo
# some helper plotting functions
from nemos import _documentation_utils as doc_plots
# configure the plotting style
plt.style.use(nmo.styles.plot_style)
from sklearn import model_selection
from sklearn import pipeline
# shut down jax to numpy conversion warning
nap.nap_config.suppress_conversion_warnings = True
WARNING:2026-02-05 17:43:51,900:jax._src.xla_bridge:876: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
path = workshop_utils.fetch_data("Achilles_10252013_EEG.nwb")
data = nap.load_file(path)
data
Achilles_10252013_EEG
┍━━━━━━━━━━━━━┯━━━━━━━━━━━━━┑
│ Keys │ Type │
┝━━━━━━━━━━━━━┿━━━━━━━━━━━━━┥
│ units │ TsGroup │
│ rem │ IntervalSet │
│ nrem │ IntervalSet │
│ forward_ep │ IntervalSet │
│ eeg │ TsdFrame │
│ theta_phase │ Tsd │
│ position │ Tsd │
┕━━━━━━━━━━━━━┷━━━━━━━━━━━━━┙
spikes = data["units"]
position = data["position"]
For today, we’re only going to focus on the times when the animal was traversing the linear track.
This is a pynapple IntervalSet, so we can use it to restrict our other variables:
Restrict data to when animal was traversing the linear track.
position = position.restrict(data["forward_ep"])
spikes = spikes.restrict(data["forward_ep"])
The recording contains both inhibitory and excitatory neurons. Here we will focus on the excitatory cells with firing rates above 0.3 Hz.
Restrict neurons to only excitatory neurons, discarding neurons with a low-firing rate.
spikes = spikes.getby_category("cell_type")["pE"]
spikes = spikes.getby_threshold("rate", 0.3)
Place fields#
By plotting the neuronal firing rate as a function of position, we can see that these neurons are all tuned for position: they fire in a specific location on the track.
Visualize the place fields: neuronal firing rate as a function of position.
place_fields = nap.compute_tuning_curves(spikes, position, bins=50, epochs=position.time_support, feature_names=["distance"])
workshop_utils.plot_place_fields(place_fields)
To decrease computation time, we’re going to spend the rest of the notebook focusing on the neurons highlighted above. We’re also going to bin spikes at 100 Hz and up-sample the position to match that temporal resolution.
For speed, we’re only going to investigate the three neurons highlighted above.
Bin spikes to counts at 100 Hz.
Interpolate position to match spike resolution.
neurons = [82, 92, 220]
place_fields = place_fields.sel(unit=neurons)
spikes = spikes[neurons]
bin_size = .01
count = spikes.count(bin_size, ep=position.time_support)
position = position.interpolate(count, ep=count.time_support)
print(count.shape)
print(position.shape)
(19237, 3)
(19237,)
Extract Speed per Epoch#
In the next block, we compute the speed of the animal for each epoch (i.e. crossing of the linear track) by taking the temporal derivative of the position. You can use pynapple’s derivative method, which computes an approximate derivative.
Compute the animal’s speed.
Visualize tuning curves to speed and position.
speed = position.derivative()
print(speed.shape)
# utility function to visualize predictions
tc_speed = nap.compute_tuning_curves(spikes, speed, bins=20, epochs=speed.time_support, feature_names=["speed"])
fig = workshop_utils.plot_position_speed(position, speed, place_fields, tc_speed, neurons);
def visualize_model_predictions(glm, X):
    # predict the model's firing rate
    predicted_rate = glm.predict(X) / bin_size
    # compute the position and speed tuning curves using the predicted firing rate
    glm_pos = nap.compute_tuning_curves(predicted_rate, position, bins=50, epochs=position.time_support, feature_names=["position"])
    glm_speed = nap.compute_tuning_curves(predicted_rate, speed, bins=30, epochs=position.time_support, feature_names=["speed"])
    workshop_utils.plot_position_speed_tuning(place_fields, tc_speed, glm_pos, glm_speed)
(19237,)
Define 1D NeMoS Bases#
Define the position and speed bases, and visualize them.
position_basis = nmo.basis.MSplineEval(n_basis_funcs=10, label="position")
speed_basis = nmo.basis.MSplineEval(n_basis_funcs=15, label="speed")
workshop_utils.plot_pos_speed_bases(position_basis, speed_basis);
Basis Composition#
The first new concept we will introduce is basis composition. NeMoS bases can be composed using the “+” operator (and “*”; see the NeMoS docs for more info) to define more complex predictors.
Adding two 1D bases results in a 2D additive basis. The compute_features method of the additive basis requires two inputs, and its output is the concatenation of the design matrices of the basis components.
Adding the position and speed bases together defines a 2D basis.
Call compute_features to define a design matrix that concatenates both features.
basis = position_basis + speed_basis
X = basis.compute_features(position, speed)
X_numpy = np.concatenate(
[
position_basis.compute_features(position),
speed_basis.compute_features(speed),
],
axis=1
)
print("Are the design matrices equivalent?", np.all(X.d == X_numpy.d))
Are the design matrices equivalent? True
Scikit-learn#
How to know when to regularize?#
In the head direction project, we fit the all-to-all connectivity of the head-tuning dataset using the Ridge regularizer, and we learned that regularization can combat overfitting. What we didn’t show is how to choose a proper regularizer. Generally, too much regularization leads to underfitting, i.e. the model is too simple and doesn’t capture the neural variability well. Too little regularization may overfit, especially when we have a large number of parameters, i.e. our model will capture both signal and noise. This is what we saw in the head direction notebook when we used the raw spike history as a predictor.
What we are looking for is a regularization strength that balances out the bias towards simpler models with the variance necessary to explain the data. However, how do we know how much we should regularize? One thing we can do is use cross-validation to see whether model performance on unseen data improves with regularization (behind the scenes, this is what we did!). We’ll walk through how to do that now.
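The idea behind K-fold cross-validation can be sketched in a few lines of plain scikit-learn. In this illustrative sketch, a synthetic regression problem and a Ridge regressor stand in for the real design matrix and GLM (all data and names below are made up for the example):

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold

# synthetic stand-in data (the real X would be a NeMoS design matrix)
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
y = X @ rng.normal(size=5) + rng.normal(scale=0.1, size=100)

# average the test-fold scores for each candidate regularization strength
scores = {}
for alpha in [1e-4, 1.0]:
    fold_scores = []
    for train_idx, test_idx in KFold(n_splits=5).split(X):
        model = Ridge(alpha=alpha).fit(X[train_idx], y[train_idx])
        # score on the held-out fold, not the training data
        fold_scores.append(model.score(X[test_idx], y[test_idx]))
    scores[alpha] = np.mean(fold_scores)
print(scores)
```

GridSearchCV, introduced next, automates exactly this loop (and can parallelize it).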
Instead of implementing our own cross-validation machinery, we wrote NeMoS to be compliant with scikit-learn, the canonical machine learning library in Python. Our models are all what scikit-learn calls “estimators”, which means they have .fit, .score, and .predict methods. Thus, we can use them with scikit-learn’s objects out of the box.
We’re going to use scikit-learn’s GridSearchCV object, which performs a cross-validated grid search, as Edoardo explained in his presentation.
This object requires an estimator, our glm object here, and param_grid, a dictionary defining what to check. For now, let’s just compare Ridge regularization at two different strengths:
How do we decide when to use regularization?
Cross-validation allows you to fairly compare different models on the same dataset.
NeMoS makes use of scikit-learn, the standard machine learning library in python.
Define parameter grid to search over.
Anything not specified in the grid will be kept constant.
# define a Ridge PopulationGLM
glm = nmo.glm.PopulationGLM(
regularizer="Ridge",
solver_kwargs={"tol": 1e-12},
solver_name="LBFGS",
)
param_grid = {
"regularizer_strength": [0.0001, 1.],
}
Initialize scikit-learn’s model_selection.GridSearchCV object.
cv_folds = 5
cv = model_selection.GridSearchCV(glm, param_grid, cv=cv_folds)
cv
GridSearchCV(cv=5,
estimator=PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=Ridge(), regularizer_strength=1.0, solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'),
             param_grid={'regularizer_strength': [0.0001, 1.0]})
This will take a bit to run, because we’re fitting the model many times!
We interact with this in a very similar way to the glm object.
In particular, call fit with the same arguments:
cv.fit(X, count)
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
GridSearchCV(cv=5,
estimator=PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=Ridge(), regularizer_strength=1.0, solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'),
             param_grid={'regularizer_strength': [0.0001, 1.0]})
Let’s investigate results:
Cross-validation results are stored in a dictionary attribute called cv_results_, which contains a lot of info. Let’s convert it to a pandas DataFrame for readability:
pd.DataFrame(cv.cv_results_)
| mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_regularizer_strength | params | split0_test_score | split1_test_score | split2_test_score | split3_test_score | split4_test_score | mean_test_score | std_test_score | rank_test_score | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.759626 | 0.687586 | 0.363913 | 0.440074 | 0.0001 | {'regularizer_strength': 0.0001} | -0.127700 | -0.118069 | -0.161886 | -0.131911 | -0.125092 | -0.132931 | 0.015160 | 1 |
| 1 | 2.978946 | 2.513113 | 0.006137 | 0.001482 | 1.0000 | {'regularizer_strength': 1.0} | -0.137128 | -0.128441 | -0.173226 | -0.138692 | -0.131281 | -0.141753 | 0.016175 | 2 |
The most informative for us is the 'mean_test_score' key, which shows the average of glm.score on each test-fold. Thus, higher is better, and we can see that the model performs better with lower regularization strength.
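Besides cv_results_, a fitted GridSearchCV also exposes the winning configuration directly through its best_params_ and best_score_ attributes. Here is a self-contained sketch of that pattern, with a scikit-learn Ridge regressor and synthetic data standing in for the GLM and the real design matrix (all names and data are illustrative):

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

# synthetic stand-in problem; the real estimator would be the NeMoS GLM
rng = np.random.default_rng(1)
X = rng.normal(size=(80, 4))
y = X @ rng.normal(size=4)

cv = GridSearchCV(Ridge(), {"alpha": [1e-4, 1.0]}, cv=5)
cv.fit(X, y)
# the best configuration and its mean test score are stored on the object
print(cv.best_params_)
print(cv.best_score_)
```

The same attributes are available on the GridSearchCV wrapping our PopulationGLM after fitting.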
Find the best regularization strength!
As an exercise, spend 10 minutes trying to find the best regularization strength!
You should use the glm model we defined in this section.
You will need to redefine the param_grid dictionary, selecting different values for "regularizer_strength":
param_grid = {
"regularizer_strength": ...,
}
After defining param_grid, reinitialize cv (you can do so with the same arguments).
Then call cv.fit and re-run pd.DataFrame(cv.cv_results_) to summarize the results.
Who can find the best regularization strength?
If you finish early, try out different regularizers and try to find the best regularization strength for each of them.
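One reasonable starting point for the exercise (a suggestion, not the required answer): search strengths on a logarithmic grid, which covers several orders of magnitude with few fits. The exact range below is only an example:

```python
import numpy as np

# candidate strengths spanning several orders of magnitude
param_grid = {
    "regularizer_strength": np.logspace(-6, 0, 7),  # 1e-6, 1e-5, ..., 1
}
print(param_grid["regularizer_strength"])
```

You can then zoom in with a finer grid around whichever value wins.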
Select basis#
We can do something similar to select the basis. In the above example, I just told you which basis function to use and how many of each. But, in general, you want to select those in a reasonable manner. Cross-validation to the rescue!
Unlike the glm objects, our basis objects are not scikit-learn compatible right out of the box. However, they can be made compatible by using the .to_transformer() method (or, equivalently, by using the TransformerBasis class).
You can (and should) do something similar to determine how many basis functions you need for each input.
NeMoS basis objects are not scikit-learn-compatible right out of the box.
But we have provided a simple method to make them so:
position_basis = nmo.basis.MSplineEval(n_basis_funcs=10, label="position").to_transformer()
# or equivalently:
position_basis = nmo.basis.TransformerBasis(nmo.basis.MSplineEval(n_basis_funcs=10, label="position"))
position_basis
'position': MSplineEval(n_basis_funcs=10, order=4)
This gives the basis object the transform method, which is equivalent to compute_features. However, transformers have some limits:
This gives the basis object the transform method, which is equivalent to compute_features.
However, transformers have some limits:
position_basis.transform(position)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[17], line 1
----> 1 position_basis.transform(position)
File ~/.venv/lib/python3.12/site-packages/nemos/basis/_transformer_basis.py:247, in TransformerBasis.transform(self, X, y)
210 """
211 Transform the data using the fitted basis functions.
212
(...) 244 >>> feature_transformed = transformer.transform(X)
245 """
246 self._check_initialized(self.basis)
--> 247 self._check_input(X, y)
248 # transpose does not work with pynapple
249 # can't use func(*X.T) to unwrap
250 return self.basis.compute_features(*self._unpack_inputs(X))
File ~/.venv/lib/python3.12/site-packages/nemos/basis/_transformer_basis.py:554, in TransformerBasis._check_input(self, X, y)
551 raise ValueError("The input must be a 2-dimensional array.")
553 elif ndim != 2:
--> 554 raise ValueError(
555 f"X must be 2-dimensional, shape (n_samples, n_features). The provided X has shape {X.shape} instead."
556 )
558 if X.shape[1] != sum(self._input_shape_product):
559 raise ValueError(
560 f"Input mismatch: expected {sum(self._input_shape_product)} inputs, but got {X.shape[1]} "
561 f"columns in X.\nTo modify the required number of inputs, call `set_input_shape` before using "
562 f"`fit` or `fit_transform`."
563 )
ValueError: X must be 2-dimensional, shape (n_samples, n_features). The provided X has shape (19237,) instead.
Transformers only accept 2d inputs, whereas nemos basis objects can accept inputs of any dimensionality.
In order to use a basis as a transformer, you’ll need to concatenate all your input in a single 2D array.
position_basis.transform(position[:, np.newaxis])
Time (s) 0 1 2 3 4 ...
--------------- -------- ---------- ----------- ----------- --- -----
18193.603802655 0.162854 0.0063002 5.33261e-05 1.12053e-07 0 ...
18193.613802655 0.159556 0.00790116 8.52389e-05 2.27854e-07 0 ...
18193.623802655 0.156303 0.00946868 0.000124421 4.04332e-07 0 ...
18193.633802655 0.151106 0.0119478 0.000203448 8.54051e-07 0 ...
18193.643802655 0.145902 0.014398 0.000303689 1.57397e-06 0 ...
18193.653802655 0.141967 0.0162281 0.000394145 2.34623e-06 0 ...
18193.663802655 0.139957 0.0171555 0.000445419 2.8306e-06 0 ...
... ...
20123.332682821 0 0 0 0 0 ...
20123.342682821 0 0 0 0 0 ...
20123.352682821 0 0 0 0 0 ...
20123.362682821 0 0 0 0 0 ...
20123.372682821 0 0 0 0 0 ...
20123.382682821 0 0 0 0 0 ...
20123.392682821 0 0 0 0 0 ...
dtype: float32, shape: (19237, 10)
Other Caveats
If the basis has more than one component (for example, if it is the addition of two 1D bases), the transformer will expect an input shape of (n_samples, 1) per component. If that’s not the case, you can provide a different input shape by calling set_input_shape.
Case 1) One input per component:
# generate a composite basis
basis_2d = nmo.basis.MSplineEval(5) + nmo.basis.MSplineEval(5)
basis_2d = basis_2d.to_transformer()
# this will work: 1 input per component
x, y = np.random.randn(10, 1), np.random.randn(10, 1)
X = np.concatenate([x, y], axis=1)
result = basis_2d.transform(X)
Case 2) Multiple inputs per component.
If one or more bases process multiple inputs (multiple columns of the 2D array), calling the transform method directly will lead to an error.
This is because the basis doesn’t know which component should process which column.
# Assume 2 inputs for the first component and 3 for the second.
x, y = np.random.randn(10, 2), np.random.randn(10, 3)
X = np.concatenate([x, y], axis=1)
res = basis_2d.transform(X) # This will raise an exception!
To prevent that, use set_input_shape to define how many inputs each component should process.
# Set the expected input shape instead, different options:
# array
res1 = basis_2d.set_input_shape(x, y).transform(X)
# int
res2 = basis_2d.set_input_shape(2, 3).transform(X)
# tuple
res3 = basis_2d.set_input_shape((2,), (3,)).transform(X)
Let’s now create the composite basis for speed and position.
position_basis = nmo.basis.MSplineEval(n_basis_funcs=10, label="position")
speed_basis = nmo.basis.MSplineEval(n_basis_funcs=15, label="speed")
basis = position_basis + speed_basis
basis = basis.to_transformer()
basis
'(position + speed)': AdditiveBasis(
basis1='position': MSplineEval(n_basis_funcs=10, order=4),
basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
)
Let’s create a single TsdFrame to hold all our inputs:
Stack position and speed in a single TsdFrame to hold all our inputs:
transformer_input = nap.TsdFrame(
t=position.t,
d=np.stack([position, speed]).T,
time_support=position.time_support,
columns=["position", "speed"],
)
Pass this input to our transformed additive basis.
Our new additive transformer basis can then take these behavioral inputs and turn them into the model’s design matrix.
basis.transform(transformer_input)
Time (s) 0 1 2 3 4 ...
--------------- -------- ---------- ----------- ----------- --- -----
18193.603802655 0.162854 0.0063002 5.33261e-05 1.12053e-07 0 ...
18193.613802655 0.159556 0.00790116 8.52389e-05 2.27854e-07 0 ...
18193.623802655 0.156303 0.00946868 0.000124421 4.04332e-07 0 ...
18193.633802655 0.151106 0.0119478 0.000203448 8.54051e-07 0 ...
18193.643802655 0.145902 0.014398 0.000303689 1.57397e-06 0 ...
18193.653802655 0.141967 0.0162281 0.000394145 2.34623e-06 0 ...
18193.663802655 0.139957 0.0171555 0.000445419 2.8306e-06 0 ...
... ...
20123.332682821 0 0 0 0 0 ...
20123.342682821 0 0 0 0 0 ...
20123.352682821 0 0 0 0 0 ...
20123.362682821 0 0 0 0 0 ...
20123.372682821 0 0 0 0 0 ...
20123.382682821 0 0 0 0 0 ...
20123.392682821 0 0 0 0 0 ...
dtype: float32, shape: (19237, 25)
Pipelines#
We need one more step: scikit-learn cross-validation operates on an estimator, like our GLMs. If we want to cross-validate over the basis or its features, we need to combine our transformer basis with the estimator into a single estimator object. Luckily, scikit-learn provides tools for this: pipelines.
Pipelines are objects that accept a series of (0 or more) transformers, culminating in a final estimator. This is defined as a list of tuples, with each tuple containing a human-readable label and the object itself:
If we want to cross-validate over the basis, we need one more step: combining the basis and the GLM into a single scikit-learn estimator.
Pipelines to the rescue!
# define a new (unregularized) PopulationGLM for the pipeline
glm = nmo.glm.PopulationGLM(solver_name="LBFGS", solver_kwargs={"tol": 10**-12})
pipe = pipeline.Pipeline([
("basis", basis),
("glm", glm)
])
pipe
Pipeline(steps=[('basis',
Transformer('(position + speed)': AdditiveBasis(
basis1='position': MSplineEval(n_basis_funcs=10, order=4),
basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
('glm',
                 PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=UnRegularized(), solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'))])
This pipeline object allows us to, e.g., call fit using the initial input:
The pipeline runs basis.transform, then passes that output to glm, so we can do everything in a single line:
pipe.fit(transformer_input, count)
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
Pipeline(steps=[('basis',
Transformer('(position + speed)': AdditiveBasis(
basis1='position': MSplineEval(n_basis_funcs=10, order=4),
basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
('glm',
                 PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=UnRegularized(), solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'))])
We then visualize the predictions the same as before, using pipe instead of glm.
Visualize model predictions!
visualize_model_predictions(pipe, transformer_input)
Cross-validating on the basis#
Now that we have our pipeline estimator, we can cross-validate on any of its parameters!
pipe.steps
[('basis',
Transformer('(position + speed)': AdditiveBasis(
basis1='position': MSplineEval(n_basis_funcs=10, order=4),
basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
('glm',
PopulationGLM(
observation_model=PoissonObservations(),
inverse_link_function=exp,
regularizer=UnRegularized(),
solver_name='LBFGS',
solver_kwargs={'tol': 1e-12}
))]
Let’s cross-validate on the number of basis functions for the position basis, and on the identity of the basis for speed. That is:
Let’s cross-validate on:
The number of basis functions of the position basis
The functional form of the basis for speed
Let’s retrieve those attributes from the pipeline:
# the label of the pipeline step retrieves the basis
print(pipe["basis"])
# the position basis can be retrieved by its label
print("\n", pipe["basis"]["position"])
# the n_basis_funcs is an attribute
print("\n", pipe["basis"]["position"].n_basis_funcs)
# with the same syntax we can retrieve the speed basis
print("\n", pipe["basis"]["speed"])
Transformer('(position + speed)': AdditiveBasis(
basis1='position': MSplineEval(n_basis_funcs=10, order=4),
basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))
Transformer('position': MSplineEval(n_basis_funcs=10, order=4))
10
Transformer('speed': MSplineEval(n_basis_funcs=15, order=4))
For scikit-learn parameter grids, we use __ to stand in for .:
Construct param_grid, using __ to stand in for ..
In scikit-learn pipelines, we access nested parameters using double underscores:
pipe["basis"]["position"].n_basis_funcs - normal Python syntax
"basis__position__n_basis_funcs" - scikit-learn parameter grid syntax
param_grid = {
"basis__position__n_basis_funcs": [5, 10, 20],
"basis__speed": [nmo.basis.MSplineEval(15),
nmo.basis.BSplineEval(15),
nmo.basis.RaisedCosineLinearEval(15)],
}
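If you are unsure which "__" names a pipeline accepts, you can list them with get_params(). Here is a self-contained sketch using stand-in scikit-learn steps (StandardScaler and Ridge are illustrative; the same call works on our basis + GLM pipeline):

```python
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# stand-in pipeline; get_params() lists every grid-addressable name
pipe = Pipeline([("scaler", StandardScaler()), ("model", Ridge())])
nested = [name for name in pipe.get_params() if "__" in name]
print(nested)  # includes names like 'model__alpha'
```

Any of these names can be used as a key in param_grid.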
Cross-validate as before:
cv = model_selection.GridSearchCV(pipe, param_grid, cv=cv_folds)
cv.fit(transformer_input, count)
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
GridSearchCV(cv=5,
estimator=Pipeline(steps=[('basis',
Transformer('(position + speed)': AdditiveBasis(
basis1='position': MSplineEval(n_basis_funcs=10, order=4),
basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
('glm',
PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=UnRegularized(), solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'))]),
param_grid={'basis__position__n_basis_funcs': [5, 10, 20],
'basis__speed': [MSplineEval(n_basis_funcs=15),
BSplineEval(n_basis_funcs=15),
RaisedCosineLinearEval(n_basis_funcs=15)]})
Investigate results:
pd.DataFrame(cv.cv_results_)
| mean_fit_time | std_fit_time | mean_score_time | std_score_time | param_basis__position__n_basis_funcs | param_basis__speed | params | split0_test_score | split1_test_score | split2_test_score | split3_test_score | split4_test_score | mean_test_score | std_test_score | rank_test_score | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.922558 | 0.666500 | 0.527833 | 0.528948 | 5 | (((MSplineEval(n_basis_funcs=15, order=4)))) | {'basis__position__n_basis_funcs': 5, 'basis__... | -0.136246 | -0.116545 | -0.147549 | -0.122671 | -0.128592 | -0.130321 | 0.010800 | 3 |
| 1 | 1.495269 | 0.247778 | 0.037604 | 0.004972 | 5 | (((BSplineEval(n_basis_funcs=15, order=4)))) | {'basis__position__n_basis_funcs': 5, 'basis__... | -0.150079 | -0.155905 | -0.172195 | -0.126545 | -0.198609 | -0.160667 | 0.023965 | 8 |
| 2 | 1.495143 | 0.226423 | 0.144760 | 0.129065 | 5 | (((RaisedCosineLinearEval(n_basis_funcs=15, wi... | {'basis__position__n_basis_funcs': 5, 'basis__... | -0.163750 | -0.128224 | -0.184139 | -0.125410 | -0.216403 | -0.163585 | 0.034409 | 9 |
| 3 | 1.619937 | 0.185751 | 0.145199 | 0.046221 | 10 | (((MSplineEval(n_basis_funcs=15, order=4)))) | {'basis__position__n_basis_funcs': 10, 'basis_... | -0.127216 | -0.113708 | -0.148560 | -0.122756 | -0.132335 | -0.128915 | 0.011573 | 2 |
| 4 | 1.626782 | 0.260083 | 0.055368 | 0.008168 | 10 | (((BSplineEval(n_basis_funcs=15, order=4)))) | {'basis__position__n_basis_funcs': 10, 'basis_... | -0.153041 | -0.131869 | -0.165941 | -0.123137 | -0.176732 | -0.150144 | 0.020141 | 6 |
| 5 | 1.539392 | 0.141097 | 0.057420 | 0.009103 | 10 | (((RaisedCosineLinearEval(n_basis_funcs=15, wi... | {'basis__position__n_basis_funcs': 10, 'basis_... | -0.148822 | -0.126507 | -0.168947 | -0.122744 | -0.168339 | -0.147072 | 0.019742 | 4 |
| 6 | 1.914416 | 0.385908 | 0.283016 | 0.193542 | 20 | (((MSplineEval(n_basis_funcs=15, order=4)))) | {'basis__position__n_basis_funcs': 20, 'basis_... | -0.126687 | -0.112765 | -0.146401 | -0.124083 | -0.123058 | -0.126599 | 0.010976 | 1 |
| 7 | 1.756607 | 0.184487 | 0.079363 | 0.019208 | 20 | (((BSplineEval(n_basis_funcs=15, order=4)))) | {'basis__position__n_basis_funcs': 20, 'basis_... | -0.153473 | -0.132539 | -0.167847 | -0.124594 | -0.179855 | -0.151662 | 0.020777 | 7 |
| 8 | 1.638160 | 0.157081 | 0.102488 | 0.014956 | 20 | (((RaisedCosineLinearEval(n_basis_funcs=15, wi... | {'basis__position__n_basis_funcs': 20, 'basis_... | -0.146942 | -0.128252 | -0.172728 | -0.124165 | -0.167277 | -0.147873 | 0.019709 | 5 |
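A convenient way to rank the configurations at a glance is to sort the results table by rank_test_score. Here is a minimal sketch using a mock dictionary with the same structure as cv.cv_results_ (with a real fit, you would pass cv.cv_results_ directly; the mock values here are made up for illustration):

```python
import pandas as pd

# mock of the structure GridSearchCV stores in cv.cv_results_ (values are illustrative)
mock_results = {
    "param_basis__position__n_basis_funcs": [5, 10, 20],
    "mean_test_score": [-0.130, -0.129, -0.127],
    "rank_test_score": [3, 2, 1],
}

# sort so the best-ranked configuration comes first
df = pd.DataFrame(mock_results).sort_values("rank_test_score")
print(df.iloc[0]["param_basis__position__n_basis_funcs"])
```

With the real `cv.cv_results_`, the same one-liner (`pd.DataFrame(cv.cv_results_).sort_values("rank_test_score")`) puts the winning parameter combination in the first row.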
scikit-learn does not cache every model that it runs (that could get prohibitively large!), but it does store the best estimator, in the appropriately-named best_estimator_ attribute.
We can easily grab the best estimator, i.e., the pipeline that performed best:
best_estim = cv.best_estimator_
best_estim
Pipeline(steps=[('basis',
Transformer('(position + speed)': AdditiveBasis(
basis1='position': MSplineEval(n_basis_funcs=20, order=4),
basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
('glm',
PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=UnRegularized(), solver_kwargs={'tol': 1e-12}, solver_name='LBFGS'))])
We can then visualize the predictions of best_estim, just as before.
Visualize model predictions!
visualize_model_predictions(best_estim, transformer_input)
Find the best basis!
As an exercise, spend 10 minutes exploring the possible basis objects and seeing which performs the best.
You should use the pipe object we defined in this section.
You will need to redefine the param_grid dictionary, setting basis__speed and basis__position (or their attributes, e.g., basis__position__n_basis_funcs) to a range of values. Remember that all combinations are tested, so if you e.g., select 5 choices for each, you’ll be testing 25 different combinations!
param_grid = {
"basis__position": ...,
"basis__speed": ...,
}
After defining param_grid, reinitialize cv (you can do so with the same arguments).
Then call cv.fit and re-run pd.DataFrame(cv.cv_results_) to summarize the results.
Finally, visualize the best estimator with visualize_model_predictions(cv.best_estimator_, transformer_input).
Who can find the best set of basis objects?
Feature selection#
Now that we understand how scikit-learn works with NeMoS, we can determine whether both position and speed are necessary inputs by performing feature selection.
Our goal is to compare alternative models: position + speed, position only, or speed only. However, scikit-learn’s cross-validation assumes that the input to the pipeline stays constant—only the hyperparameters change. So how can we compare models that require different features?
Here’s a clever NeMoS trick: we’ll create a “null” basis that produces zero features. This way, all models take the same 2-D input (position, speed), but some features become empty arrays. We can define this null basis using CustomBasis, which creates a basis from a list of functions.
# define a function that creates an empty array (n_samples, 0)
def func(x):
return np.zeros((x.shape[0], 0))
# create a null transformer basis using the custom basis class
null_basis = nmo.basis.CustomBasis([func]).to_transformer()
# verify: this creates an empty feature array
null_basis.compute_features(position).shape
(19237, 0)
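The reason this trick works: horizontally stacking an (n_samples, 0) array with real features leaves the design matrix unchanged, so the "null" predictor simply contributes no columns. A minimal NumPy-only sketch (the variable names here are illustrative, not from the notebook):

```python
import numpy as np

n_samples = 100
position_features = np.random.randn(n_samples, 10)  # stand-in for a 10-function basis
null_features = np.zeros((n_samples, 0))            # the "null" basis output

# concatenating the zero-width block adds no columns to the design matrix
design = np.hstack([position_features, null_features])
print(design.shape)  # (100, 10): identical to position_features alone
```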
Why is this useful? We can combine null_basis with actual bases to create different models that all accept the same input!
Let’s define the bases for our three models:
Position + speed: combine position and speed bases
Position only: combine position basis with null basis (the speed feature block is empty)
Speed only: combine null basis with speed basis (the position feature block is empty)
# combine them to define each model
basis_all = (position_basis + speed_basis).to_transformer()
basis_position = (position_basis + null_basis).to_transformer()
basis_speed = (null_basis + speed_basis).to_transformer()
# assign labels (optional but helpful for readability)
basis_all.label = "position + speed"
basis_position.label = "position only"
basis_speed.label = "speed only"
These bases can all transform the same transformer_input (a TsdFrame with columns for position and speed), but they generate design matrices with different numbers of features:
# "position + speed" design: 25 features (10 + 15)
print("position + speed design matrix shape:")
print(basis_all.transform(transformer_input).shape)
# "position" design: 10 features (10 + 0)
print("\nposition design matrix shape:")
print(basis_position.transform(transformer_input).shape)
# "speed" design: 15 features (0 + 15)
print("\nspeed design matrix shape:")
print(basis_speed.transform(transformer_input).shape)
position + speed design matrix shape:
(19237, 25)
position design matrix shape:
(19237, 10)
speed design matrix shape:
(19237, 15)
To cross-validate over different basis compositions, we need to understand how they’re stored in our pipeline. The additive basis is stored as a basis attribute inside the TransformerBasis object:
# the "basis" step in our pipeline contains a TransformerBasis
# which has a "basis" attribute storing the actual additive basis
pipe["basis"].basis
'(position + speed)': AdditiveBasis(
basis1='position': MSplineEval(n_basis_funcs=10, order=4),
basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
)
Now we can create a parameter grid for cross-validation. The key is the string "basis__basis":
The first basis is the name of the pipeline step.
The second basis is the attribute of the TransformerBasis object.
This double-underscore notation is how scikit-learn accesses nested parameters.
# create parameter grid with our three basis compositions
param_grid = {
"basis__basis": [
basis_all, # position + speed
basis_position, # position only
basis_speed # speed only
],
}
# define and fit GridSearchCV
cv = model_selection.GridSearchCV(pipe, param_grid, cv=cv_folds)
cv.fit(transformer_input, count)
/home/agent/workspace/rorse_ccn-software-feb-2026_main/.venv/lib/python3.12/site-packages/nemos/glm/glm.py:728: RuntimeWarning: The fit did not converge. Consider the following:
1) Enable float64 with ``jax.config.update('jax_enable_x64', True)``
2) Increase the max number of iterations or increase tolerance (if reasonable). These parameters can be specified by providing a ``solver_kwargs`` dictionary. For the available options see the ``self.solver.__init__`` docstrings.
warnings.warn(
GridSearchCV(cv=5,
estimator=Pipeline(steps=[('basis',
Transformer('(position + speed)': AdditiveBasis(
basis1='position': MSplineEval(n_basis_funcs=10, order=4),
basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))),
('glm',
PopulationGLM(inverse_link_function=<function exp at 0x7fd3d91c99e0>, observation_model=PoissonObservations(), regularizer=UnRegularized(...
basis1='position': MSplineEval(n_basis_funcs=10, order=4),
basis2=CustomBasis(
funcs=[func],
ndim_input=1,
output_shape=(0,),
pynapple_support=True,
is_complex=False
),
)),
Transformer('speed only': AdditiveBasis(
basis1=CustomBasis(
funcs=[func],
ndim_input=1,
output_shape=(0,),
pynapple_support=True,
is_complex=False
),
basis2='speed': MSplineEval(n_basis_funcs=15, order=4),
))]})
cv_df = pd.DataFrame(cv.cv_results_)
# display the key columns: which basis was used, its score, and ranking
cv_df[["param_basis__basis", "mean_test_score", "rank_test_score"]]
| | param_basis__basis | mean_test_score | rank_test_score |
|---|---|---|---|
| 0 | Transformer('position + speed': AdditiveBasis(... | -0.128915 | 2 |
| 1 | Transformer('position only': AdditiveBasis(... | -0.118960 | 1 |
| 2 | Transformer('speed only': AdditiveBasis(... | -0.175172 | 3 |
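Beyond `cv_results_`, the fitted `GridSearchCV` object exposes the winning configuration and refit estimator directly. A minimal sketch of this workflow on a toy scikit-learn regression (the data and `Ridge` model here are hypothetical stand-ins, not the place-field model):

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

# toy regression data, just to demonstrate the cv_results_ workflow
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=200)

cv = GridSearchCV(Ridge(), {"alpha": [0.01, 1.0, 100.0]}, cv=5)
cv.fit(X, y)

# the row with rank_test_score == 1 matches best_params_,
# and best_estimator_ is already refit on the full data
cv_df = pd.DataFrame(cv.cv_results_)
best_row = cv_df.loc[cv_df["rank_test_score"] == 1]
print(cv.best_params_)
```

The same pattern applies to the pipeline above: `cv.best_estimator_` is a fitted pipeline you can call `predict` on directly.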
Find the model!
In this section, we only compared a single choice of regularization strength and basis objects for each feature. As an exercise, spend 10 minutes combining what we learned here with the earlier sections: for each feature combination (position, speed, position + speed), try several different basis objects and, optionally, different regularization strengths.
Don’t forget to visualize your model’s predictions!
Who can find the best model?
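One way to structure this exercise is to pass `GridSearchCV` a list of dictionaries, so regularization strength is only varied when a penalized regularizer is used. A sketch of the grid's shape (here `basis_all`, `basis_position`, and `basis_speed` stand in for the transformer objects defined earlier in the notebook, replaced by plain strings so this fragment runs on its own):

```python
import numpy as np

# stand-ins for the TransformerBasis objects built earlier in the notebook
basis_all, basis_position, basis_speed = "all", "position", "speed"

# each dict is expanded into its own sub-grid by GridSearchCV
param_grid = [
    {
        "basis__basis": [basis_all, basis_position, basis_speed],
        "glm__regularizer": ["UnRegularized"],
    },
    {
        "basis__basis": [basis_all, basis_position, basis_speed],
        "glm__regularizer": ["Ridge"],
        "glm__regularizer_strength": np.logspace(-4, 0, 5).tolist(),
    },
]
```

This keeps the search honest: unregularized fits are not needlessly repeated once per strength value.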
Next Steps#
For the next project, you can use all the tools showcased here to find a better encoding model for these hippocampal neurons.
Suggestions:
Extend the model by including theta phase as a predictor
Use the NeMoS MultiplicativeBasis to capture interactions between theta phase and position
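Conceptually, a multiplicative basis forms every pairwise product of two feature sets at each sample, which is what lets it capture interactions such as phase-by-position tuning. A minimal NumPy sketch of that idea (this is an illustration of the concept, not the NeMoS implementation):

```python
import numpy as np

def multiplicative_features(A, B):
    # A: (n_samples, n_a), B: (n_samples, n_b)
    # returns (n_samples, n_a * n_b): all pairwise products per sample
    return (A[:, :, None] * B[:, None, :]).reshape(A.shape[0], -1)

rng = np.random.default_rng(0)
A = rng.normal(size=(100, 10))  # e.g. position basis features
B = rng.normal(size=(100, 5))   # e.g. theta-phase basis features
X = multiplicative_features(A, B)
print(X.shape)  # (100, 50)
```

Note how the feature count multiplies (10 × 5 = 50), so regularization and cross-validation become even more important for interaction models.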