Fit head-direction population
Learning objectives
- Learn how to add history-related predictors to NeMoS GLM
- Learn how to reduce over-fitting with Basis
- Learn how to cross-validate with NeMoS + scikit-learn
import matplotlib.pyplot as plt
import numpy as np
import pynapple as nap
import warnings
import workshop_utils
import nemos as nmo
from sklearn.model_selection import GridSearchCV
warnings.filterwarnings("ignore")
# configure pynapple to ignore conversion warning
nap.nap_config.suppress_conversion_warnings = True
# configure plot style
plt.style.use(workshop_utils.STYLE_FILE)
Data Streaming
- Stream the head-direction neurons data
path = workshop_utils.fetch_data("Mouse32-140822.nwb")
- Open the NWB file with pynapple's load_file and preview its contents.
data = nap.load_file(path)
data
- Load the units
spikes = data["units"]
spikes
- Load the epochs and take only wakefulness
epochs = data["epochs"]
wake_ep = data["epochs"]["wake"]
- Load the angular head-direction of the animal (in radians)
angle = data["ry"]
- Select only those units that are in ADn
- Restrict the activity to wakefulness (both the spiking activity and the angle)
spikes = spikes.getby_category("location")["adn"]
spikes = spikes.restrict(wake_ep).getby_threshold("rate", 1.0)
angle = angle.restrict(wake_ep)
- Compute tuning curves as a function of head-direction
tuning_curves = nap.compute_1d_tuning_curves(
group=spikes, feature=angle, nb_bins=61, minmax=(0, 2 * np.pi)
)
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(tuning_curves.iloc[:, 0])
ax[0].set_xlabel("Angle (rad)")
ax[0].set_ylabel("Firing rate (Hz)")
ax[1].plot(tuning_curves.iloc[:, 1])
ax[1].set_xlabel("Angle (rad)")
plt.tight_layout()
- Let's visualize the data at the population level.
fig = workshop_utils.plotting.plot_head_direction_tuning(
tuning_curves, spikes, angle, threshold_hz=1, start=8910, end=8960
)
- Take the first 3 minutes of wakefulness to speed up optimization
wake_ep = nap.IntervalSet(
start=wake_ep.start[0], end=wake_ep.start[0] + 3 * 60
)
- Bin the spike trains into 10 ms bins
bin_size = 0.01
count = spikes.count(bin_size, ep=wake_ep)
- Sort the neurons by their preferred direction using pandas
pref_ang = tuning_curves.idxmax()
count = nap.TsdFrame(
t=count.t,
d=count.values[:, np.argsort(pref_ang.values)],
)
NeMoS
Self-Connected Single Neuron
- Start by modeling a single self-connected neuron
- Select a neuron
- Select the first 1.2 seconds for visualization
# enter code here
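One possible solution, taking the first unit and the first 1.2 s of the counts' time support (the names neuron_count and epoch_one_spk are reused by the plotting call below):
# select the spike counts of the first neuron
neuron_count = count[:, 0]
# first 1.2 seconds of the count time support, for visualization
epoch_one_spk = nap.IntervalSet(
    start=count.time_support.start[0],
    end=count.time_support.start[0] + 1.2,
)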
- Visualize the spike count time course
# set the size of the spike history window in seconds
window_size_sec = 0.8
workshop_utils.plotting.plot_history_window(
neuron_count, epoch_one_spk, window_size_sec
)
- Form a predictor matrix by vertically stacking all the windows (you can use a convolution).
# enter code here
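A sketch, assuming your NeMoS version provides nmo.convolve.create_convolutional_predictor: convolving with the columns of an identity matrix (one delta kernel per lag) stacks the raw count history into a matrix.
# number of bins in the history window (0.8 s at 10 ms bins = 80)
window_size = int(window_size_sec * neuron_count.rate)
# each identity column extracts the count at one past lag
input_feature = nmo.convolve.create_convolutional_predictor(
    np.eye(window_size), neuron_count
)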
- Check the shape of the counts and features.
# enter code here
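A quick sanity check; the history matrix should have one column per lag:
print(f"count shape: {neuron_count.shape}")
print(f"feature shape: {input_feature.shape}")  # (n_bins, window_size)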
- Plot the convolution output.
suptitle = "Input feature: Count History"
neuron_id = 0
workshop_utils.plotting.plot_features(input_feature, count.rate, suptitle)
Fitting the Model
- Split your epochs in two for validation purposes.
# construct the train and test epochs
duration = neuron_count.time_support.tot_length()
start = neuron_count.time_support["start"]
end = neuron_count.time_support["end"]
first_half = nap.IntervalSet(start, start + duration / 2)
second_half = nap.IntervalSet(start + duration / 2, end)
- Fit a GLM to the first half.
# enter code here
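One possible fit; the LBFGS solver choice here is an assumption, any NeMoS solver works:
model = nmo.glm.GLM(regularizer=nmo.regularizer.UnRegularized("LBFGS"))
model.fit(
    input_feature.restrict(first_half),
    neuron_count.restrict(first_half),
)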
- Plot the weights.
workshop_utils.plotting.plot_and_compare_weights(
[model.coef_], ["GLM raw history 1st Half"], count.rate)
Inspecting the results
- Fit on the other half.
# enter code here
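Same model class, fit on the second half (the name model_second_half is reused by the comparison plot below):
model_second_half = nmo.glm.GLM(regularizer=nmo.regularizer.UnRegularized("LBFGS"))
model_second_half.fit(
    input_feature.restrict(second_half),
    neuron_count.restrict(second_half),
)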
- Compare results.
workshop_utils.plotting.plot_and_compare_weights(
[model.coef_, model_second_half.coef_],
["GLM raw history 1st Half", "GLM raw history 2nd Half"],
count.rate)
Reducing feature dimensionality
- Visualize the raised cosine basis.
workshop_utils.plotting.plot_basis()
# enter code here
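Before convolving, instantiate a basis object. A sketch, assuming the log-spaced raised cosine basis; 8 basis functions is an arbitrary choice, and constructor details vary across NeMoS versions:
basis = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=8)
# sample the basis on the window grid to get the convolution kernels
time, basis_kernels = basis.evaluate_on_grid(window_size)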
- Convolve the counts with the basis functions.
# enter code here
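One possible solution, reusing the convolution helper from above:
# shape (n_bins, n_basis_funcs): one compressed feature per basis function
conv_spk = nmo.convolve.create_convolutional_predictor(basis_kernels, neuron_count)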
- Visualize the output.
# Visualize the convolution results
epoch_one_spk = nap.IntervalSet(8917.5, 8918.5)
epoch_multi_spk = nap.IntervalSet(8979.2, 8980.2)
workshop_utils.plotting.plot_convolved_counts(
neuron_count, conv_spk, epoch_one_spk, epoch_multi_spk
)
Fit and compare the models
- Fit the model using the compressed features.
# enter code here
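A sketch, mirroring the raw-history fit above:
model_basis = nmo.glm.GLM(regularizer=nmo.regularizer.UnRegularized("LBFGS"))
model_basis.fit(
    conv_spk.restrict(first_half),
    neuron_count.restrict(first_half),
)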
- Reconstruct the history filter.
# enter code here
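The filter in time is the weighted sum of the basis kernels:
# (window_size, n_basis_funcs) @ (n_basis_funcs,) -> (window_size,)
self_connection = np.matmul(basis_kernels, model_basis.coef_)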
- Fit the model to the other half of the data.
# enter code here
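Same recipe on the second half (self_connection_second_half is reused in the comparison below):
model_basis_second_half = nmo.glm.GLM(regularizer=nmo.regularizer.UnRegularized("LBFGS"))
model_basis_second_half.fit(
    conv_spk.restrict(second_half),
    neuron_count.restrict(second_half),
)
self_connection_second_half = np.matmul(basis_kernels, model_basis_second_half.coef_)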
- Plot and compare the results.
# enter code here
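One way to compare, assuming the plotting helper accepts the reconstructed filters alongside the raw-history weights (both have window_size entries):
workshop_utils.plotting.plot_and_compare_weights(
    [model.coef_, model_second_half.coef_,
     self_connection, self_connection_second_half],
    ["GLM raw history 1st Half", "GLM raw history 2nd Half",
     "GLM basis 1st Half", "GLM basis 2nd Half"],
    count.rate,
)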
- Predict the rates
# enter code here
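A sketch: predict with each model and convert from spikes per bin to Hz; the zoom window ep used by the plot below is arbitrary:
rate_history = model.predict(input_feature) * neuron_count.rate
rate_basis = model_basis.predict(conv_spk) * neuron_count.rate
# an arbitrary 60 s window to zoom in on
ep = nap.IntervalSet(
    start=neuron_count.time_support.start[0],
    end=neuron_count.time_support.start[0] + 60,
)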
- Plot the results.
# plot the rates
workshop_utils.plotting.plot_rates_and_smoothed_counts(
neuron_count.restrict(ep),
{
    "Self-connection raw history": rate_history,
    "Self-connection basis": rate_basis,
}
)
All-to-all Connectivity
Preparing the features
- Convolve all counts.
- Print the output shape
# enter code here
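A possible solution, reusing the convolution helper; the output's axis layout, and whether you need to flatten it yourself, depend on your NeMoS version:
# convolve every neuron's counts with the same basis kernels
convolved_count = nmo.convolve.create_convolutional_predictor(basis_kernels, count)
print(convolved_count.shape)  # e.g. (n_bins, n_neurons, n_basis_funcs)
# the GLM expects a 2D design matrix: flatten (neuron, basis) into one
# feature axis, keeping the pynapple time index
convolved_count = nap.TsdFrame(
    t=count.t, d=convolved_count.values.reshape(len(count), -1)
)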
Fitting the Model
- Fit a PopulationGLM
- Use Ridge regularization with regularizer_strength=0.1
- Print the shape of the estimated coefficients.
# enter code here
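A sketch, matching the regularizer style used in the cross-validation section below:
model = nmo.glm.PopulationGLM(
    regularizer=nmo.regularizer.Ridge("LBFGS", regularizer_strength=0.1)
)
model.fit(convolved_count, count)
print(model.coef_.shape)  # (n_features, n_neurons)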
Comparing model predictions
- Predict the firing rate of each neuron
- Convert the rate from spike/bin to spike/sec
# enter code here
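One possible conversion: predict returns spikes per bin, so multiply by the sampling rate (count.rate ≈ 1 / bin_size = 100 Hz):
predicted_firing_rate = model.predict(convolved_count) * count.rate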
- Visualize the predicted rate and tuning function.
# use pynapple for time axis for all variables plotted for tick labels in imshow
workshop_utils.plotting.plot_head_direction_tuning_model(
tuning_curves, predicted_firing_rate, spikes, angle, threshold_hz=1,
start=8910, end=8960, cmap_label="hsv"
)
- Visually compare all the models.
workshop_utils.plotting.plot_rates_and_smoothed_counts(
neuron_count,
{"Self-connection: raw history": rate_history,
"Self-connection: bsais": rate_basis,
"All-to-all: basis": predicted_firing_rate[:, 0]}
)
Visualizing the connectivity
- Compute tuning curves from the predicted rates using pynapple.
# enter code here
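A possible solution with pynapple's tuning-curve function for continuous variables, mirroring the spike-based call above:
tuning = nap.compute_1d_tuning_curves_continuous(
    predicted_firing_rate, feature=angle, nb_bins=61, minmax=(0, 2 * np.pi)
)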
- Extract the weights and store them in an array of shape (num_neurons, num_neurons, num_features).
# enter code here
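A sketch; the feature ordering assumed here (sender neuron first, then basis function) should be checked against your NeMoS version:
n_neurons = count.shape[1]
# coef_ is (n_neurons * n_basis_funcs, n_neurons)
weights = np.asarray(model.coef_).reshape(
    n_neurons, basis.n_basis_funcs, n_neurons
)
# rearrange to (receiver, sender, n_basis_funcs)
weights = np.transpose(weights, (2, 0, 1))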
- Multiply the weights by the basis to get the history filters.
# enter code here
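Summing the weighted kernels recovers each pairwise filter in time, given the weight layout assumed above:
# (receiver, sender, basis) x (time, basis) -> (receiver, sender, time)
responses = np.einsum("rsb,tb->rst", weights, basis_kernels)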
- Plot the connectivity map.
workshop_utils.plotting.plot_coupling(responses, tuning)
K-fold Cross-Validation
K-fold cross-validation scheme (illustration from the scikit-learn docs).
K-fold with NeMoS and scikit-learn
- Instantiate the PopulationGLM
- Define a grid of regularization strengths.
- Instantiate and fit the GridSearchCV with 2 folds.
# define the model
model = nmo.glm.PopulationGLM(
regularizer=nmo.regularizer.Ridge("LBFGS")
)
# define a grid of parameters for the search
param_grid = dict(regularizer__regularizer_strength=np.logspace(-3, 0, 4))
print(param_grid)
# define a GridSearch cross-validation from scikit-learn
# with 2-folds
k_fold = GridSearchCV(model, param_grid=param_grid, cv=2)
- Run cross-validation!
# fit the cross-validated model
k_fold.fit(convolved_count, count)
- Print the best parameters.
print(f"Best regularization strength: "
f"{k_fold.best_params_['regularizer__regularizer_strength']}")
Exercises
- Plot the weights and rate predictions.
- What happens if you use 5 folds?
- What happens if you cross-validate each neuron individually? Do you select the same hyperparameter for every neuron?