

Fit head-direction population

Learning objectives

  • Learn how to add history-related predictors to a NeMoS GLM
  • Learn how to reduce over-fitting with a Basis object
  • Learn how to cross-validate with NeMoS + scikit-learn
import matplotlib.pyplot as plt
import numpy as np
import pynapple as nap
import warnings
import workshop_utils

import nemos as nmo
from sklearn.model_selection import GridSearchCV

warnings.filterwarnings("ignore")

# configure pynapple to ignore conversion warning
nap.nap_config.suppress_conversion_warnings = True

# configure the plot style
plt.style.use(workshop_utils.STYLE_FILE)

Data Streaming

  • Stream the head-direction neurons data
path = workshop_utils.fetch_data("Mouse32-140822.nwb")

Pynapple

  • load_file: open the NWB file and get a preview of its contents.
data = nap.load_file(path)
data
  • Load the units
spikes = data["units"]
spikes
  • Load the epochs and take only wakefulness
epochs = data["epochs"]
wake_ep = data["epochs"]["wake"]
  • Load the angular head-direction of the animal (in radians)
angle = data["ry"]
  • Select only those units that are in ADn
  • Restrict the activity to wakefulness (both the spiking activity and the angle) and keep only units firing above 1 Hz
spikes = spikes.getby_category("location")["adn"]
spikes = spikes.restrict(wake_ep).getby_threshold("rate", 1.0)
angle = angle.restrict(wake_ep)
  • Compute tuning curves as a function of head-direction
tuning_curves = nap.compute_1d_tuning_curves(
    group=spikes, feature=angle, nb_bins=61, minmax=(0, 2 * np.pi)
)
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(tuning_curves.iloc[:, 0])
ax[0].set_xlabel("Angle (rad)")
ax[0].set_ylabel("Firing rate (Hz)")
ax[1].plot(tuning_curves.iloc[:, 1])
ax[1].set_xlabel("Angle (rad)")
plt.tight_layout()
  • Let's visualize the data at the population level.
fig = workshop_utils.plotting.plot_head_direction_tuning(
    tuning_curves, spikes, angle, threshold_hz=1, start=8910, end=8960
)
  • Take the first 3 minutes of wakefulness to speed up optimization
wake_ep = nap.IntervalSet(
    start=wake_ep.start[0], end=wake_ep.start[0] + 3 * 60
)
  • Bin the spike trains in 10 ms bins
bin_size = 0.01
count = spikes.count(bin_size, ep=wake_ep)
  • sort the neurons by their preferred direction using pandas
pref_ang = tuning_curves.idxmax()

count = nap.TsdFrame(
    t=count.t,
    d=count.values[:, np.argsort(pref_ang.values)],
)
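  • (Optional) Sanity-check the sorting: printed in the new column order, the preferred angles should increase monotonically from 0 to 2π. This check is only illustrative and not part of the original workflow.
# preferred angles of the columns of `count` after sorting
print(pref_ang.values[np.argsort(pref_ang.values)])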

NeMoS

Self-Connected Single Neuron

  • Start with modeling a self-connected single neuron
  • Select a neuron
  • Select the first 1.2 seconds for visualization
# select a neuron's spike count time series
neuron_count = count[:, 0]

# restrict to a smaller time interval
epoch_one_spk = nap.IntervalSet(
    start=count.time_support.start[0], end=count.time_support.start[0] + 1.2
)
  • visualize the spike count time course
# set the size of the spike history window in seconds
window_size_sec = 0.8

workshop_utils.plotting.plot_history_window(
    neuron_count, epoch_one_spk, window_size_sec
)
  • Form a predictor matrix by vertically stacking all the windows (you can use a convolution).
# convert the prediction window to bins (by multiplying with the sampling rate)
window_size = int(window_size_sec * neuron_count.rate)

# convolve the counts with the identity matrix.
input_feature = nmo.convolve.create_convolutional_predictor(
    np.eye(window_size), neuron_count
)

# print the NaN indices along the time axis
print("NaN indices:\n", np.where(np.isnan(input_feature[:, 0]))[0])
  • Check the shape of the counts and features.
print(f"Time bins in counts: {neuron_count.shape[0]}")
print(f"Convolution window size in bins: {window_size}")
print(f"Feature shape: {input_feature.shape}")
  • Plot the convolution output.
suptitle = "Input feature: Count History"
neuron_id = 0
workshop_utils.plotting.plot_features(input_feature, count.rate, suptitle)
# The resulting feature dimension is 80, because our bin size was 0.01 sec and the window size is 0.8 sec.

# We can learn these weights by maximum likelihood by fitting a GLM.

Fitting the Model

  • Split your epochs in two for validation purposes.
# construct the train and test epochs
duration = neuron_count.time_support.tot_length()
start = neuron_count.time_support["start"]
end = neuron_count.time_support["end"]
first_half = nap.IntervalSet(start, start + duration / 2)
second_half = nap.IntervalSet(start + duration / 2, end)
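  • As a small aside (not required here), the same split can be obtained with pynapple's interval-set operations, assuming IntervalSet.set_diff behaves as in recent pynapple releases.
# the test epoch is everything in the time support that is not in the training epoch
second_half_alt = neuron_count.time_support.set_diff(first_half)
print(second_half_alt)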
  • Fit a GLM to the first half.
# define the GLM object
model = nmo.glm.GLM()

# Fit over the training epochs
model.fit(
    input_feature.restrict(first_half),
    neuron_count.restrict(first_half)
)
  • Plot the weights.
workshop_utils.plotting.plot_and_compare_weights(
    [model.coef_], ["GLM raw history 1st Half"], count.rate)

Inspecting the results

  • Fit on the other half.
model_second_half = nmo.glm.GLM()

model_second_half.fit(
    input_feature.restrict(second_half),
    neuron_count.restrict(second_half)
)
  • Compare results.
workshop_utils.plotting.plot_and_compare_weights(
    [model.coef_, model_second_half.coef_],
    ["GLM raw history 1st Half", "GLM raw history 2nd Half"],
    count.rate)
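  • Beyond the visual comparison, a simple quantitative check (purely illustrative) is the correlation between the two sets of fitted weights; with 80 free parameters and only ~90 seconds of data per half, expect only loose agreement.
# correlation between the raw-history weights fitted on the two halves
corr = np.corrcoef(model.coef_, model_second_half.coef_)[0, 1]
print(f"Correlation between 1st- and 2nd-half weights: {corr:.2f}")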

Reducing feature dimensionality

  • Visualize the raised cosine basis.
workshop_utils.plotting.plot_basis()
# a basis object can be instantiated in "conv" mode for convolving the input
basis = nmo.basis.RaisedCosineBasisLog(
    n_basis_funcs=8, mode="conv", window_size=window_size
)

# build the time-lag axis of the window (in seconds): one value per 10 ms bin,
# from 0 up to the window duration
time = bin_size * np.arange(window_size)
  • Convolve the counts with the basis functions.
# equivalent to
# `nmo.convolve.create_convolutional_predictor(basis_kernels, neuron_count)`
conv_spk = basis.compute_features(neuron_count)

print(f"Raw count history as feature: {input_feature.shape}")
print(f"Compressed count history as feature: {conv_spk.shape}")
  • Visualize the output.
# Visualize the convolution results
epoch_one_spk = nap.IntervalSet(8917.5, 8918.5)
epoch_multi_spk = nap.IntervalSet(8979.2, 8980.2)

workshop_utils.plotting.plot_convolved_counts(
    neuron_count, conv_spk, epoch_one_spk, epoch_multi_spk
)

Fit and compare the models

  • Fit the model using the compressed features.
model_basis = nmo.glm.GLM()

# restrict to the training interval set
model_basis.fit(
    conv_spk.restrict(first_half),
    neuron_count.restrict(first_half)
)
print(model_basis.coef_)
  • Reconstruct the history filter.
_, basis_kernels = basis.evaluate_on_grid(window_size)
self_connection = np.matmul(basis_kernels, model_basis.coef_)

print(self_connection.shape)
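  • (Optional) Plot the reconstructed filter against the time-lag axis defined earlier to see its shape before comparing models.
# reconstructed spike-history filter of the basis model (1st-half fit)
plt.figure()
plt.plot(time, self_connection)
plt.xlabel("Time from spike (sec)")
plt.ylabel("Weight")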
  • Fit the other half of the data.
model_basis_second_half = nmo.glm.GLM(
    regularizer=nmo.regularizer.UnRegularized("LBFGS")
)
model_basis_second_half.fit(
    conv_spk.restrict(second_half), neuron_count.restrict(second_half)
)

# compute responses for the 2nd half fit
self_connection_second_half = np.matmul(basis_kernels, model_basis_second_half.coef_)
  • Plot and compare the results.
workshop_utils.plotting.plot_and_compare_weights(
    [model.coef_, model_second_half.coef_, self_connection, self_connection_second_half],
    ["GLM raw history 1st Half", "GLM raw history 2nd half", "GLM basis 1st half", "GLM basis 2nd half"],
    count.rate
)
  • Predict the rates
rate_basis = model_basis.predict(conv_spk) * conv_spk.rate
rate_history = model.predict(input_feature) * conv_spk.rate
ep = nap.IntervalSet(start=8819.4, end=8821)
  • Plot the results.
# plot the rates
workshop_utils.plotting.plot_rates_and_smoothed_counts(
    neuron_count.restrict(ep),
    {
        "Self-connection raw history":rate_history,
        "Self-connection bsais": rate_basis
    }
)

All-to-all Connectivity

Preparing the features

  • Convolve all counts.
  • Print the output shape
# convolve all the neurons
convolved_count = basis.compute_features(count)
# shape should be `(n_samples, n_basis_func * n_neurons)`
print(f"Convolved count shape: {convolved_count.shape}")

Fitting the Model

  • Fit a PopulationGLM
  • Use Ridge regularization with a regularizer_strength=0.1
  • Print the shape of the estimated coefficients.
model = nmo.glm.PopulationGLM(
    regularizer=nmo.regularizer.Ridge("LBFGS", regularizer_strength=0.1)
).fit(convolved_count, count)

print(f"Model coefficients shape: {model.coef_.shape}")

Comparing model predictions

  • Predict the firing rate of each neuron
  • Convert the rate from spike/bin to spike/sec
# predict the rate (counts are already sorted by tuning prefs)
predicted_firing_rate = model.predict(convolved_count) * conv_spk.rate
  • Visualize the predicted rate and tuning function.
# pynapple keeps a common time axis across all plotted variables (used for the imshow tick labels)
workshop_utils.plotting.plot_head_direction_tuning_model(
    tuning_curves, predicted_firing_rate, spikes, angle, threshold_hz=1,
    start=8910, end=8960, cmap_label="hsv"
)
  • Visually compare all the models.
workshop_utils.plotting.plot_rates_and_smoothed_counts(
    neuron_count,
    {"Self-connection: raw history": rate_history,
     "Self-connection: bsais": rate_basis,
     "All-to-all: basis": predicted_firing_rate[:, 0]}
)

Visualizing the connectivity

  • Compute tuning curves from the predicted rates using pynapple.
tuning = nap.compute_1d_tuning_curves_continuous(predicted_firing_rate,
                                                 feature=angle,
                                                 nb_bins=61,
                                                 minmax=(0, 2 * np.pi))
  • Extract the weights and store them in an array of shape (num_neurons, num_basis_funcs, num_neurons).
n_neurons = count.shape[1]
weights = model.coef_.reshape(n_neurons, basis.n_basis_funcs, n_neurons)
  • Multiply the weights by the basis, to get the history filters.
responses = np.einsum("jki,tk->ijt", weights, basis_kernels)

print(responses.shape)
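  • The einsum above is compact but dense; the explicit loop below (purely illustrative) does the same computation and makes the indexing easier to read. With features laid out as blocks of basis functions per presynaptic neuron, responses[i, j] is the coupling filter from neuron j onto neuron i.
# equivalent, more explicit computation of the coupling filters
responses_loop = np.zeros((n_neurons, n_neurons, window_size))
for i in range(n_neurons):        # receiving (postsynaptic) neuron
    for j in range(n_neurons):    # sending (presynaptic) neuron
        responses_loop[i, j] = basis_kernels @ weights[j, :, i]
print(np.allclose(responses, responses_loop))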
  • Plot the connectivity map.
workshop_utils.plotting.plot_coupling(responses, tuning)

K-fold Cross-Validation

Grid Search Cross Validation
[Figure: K-fold cross-validation schematic, from the scikit-learn docs]

K-fold with NeMoS and scikit-learn

  • Instantiate the PopulationGLM
  • Define a grid of regularization strengths.
  • Instantiate and fit the GridSearchCV with 2 folds.
# define the model
model = nmo.glm.PopulationGLM(
    regularizer=nmo.regularizer.Ridge("LBFGS")
)

# define a grid of parameters for the search
param_grid = dict(regularizer__regularizer_strength=np.logspace(-3, 0, 4))
print(param_grid)

# define a GridSearch cross-validation from scikit-learn
# with 2-folds
k_fold = GridSearchCV(model, param_grid=param_grid, cv=2)
  • Run cross-validation!
# fit the cross-validated model
k_fold.fit(convolved_count, count)
  • Print the best parameters.
print(f"Best regularization strength: "
      f"{k_fold.best_params_['regularizer__regularizer_strength']}")

Exercises

  • Plot the weights and rate predictions.
  • What happens if you use 5 folds?
  • What happens if you cross-validate each neuron individually? Do you select the same hyperparameter for every neuron or not? (A starting-point sketch follows below.)
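  • One possible way to set up the last exercise (a sketch, not the only approach): run the same grid search once per neuron with a single-neuron GLM and collect the selected strength for each.
# sketch: cross-validate each neuron separately and record the chosen strength
best_strengths = []
for i in range(count.shape[1]):
    single_neuron_model = nmo.glm.GLM(regularizer=nmo.regularizer.Ridge("LBFGS"))
    cv = GridSearchCV(single_neuron_model, param_grid=param_grid, cv=2)
    cv.fit(convolved_count, count[:, i])
    best_strengths.append(cv.best_params_["regularizer__regularizer_strength"])
print(best_strengths)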


Download Python source code: 06_head_direction_code.py

Download Jupyter notebook: 06_head_direction_code.ipynb
