
%matplotlib inline
%load_ext autoreload
%autoreload 2
import warnings

warnings.filterwarnings(
    "ignore",
    message="plotting functions contained within `_documentation_utils` are intended for nemos's documentation.",
    category=UserWarning,
)

Download

This notebook can be downloaded as visual_coding-users.ipynb. See the button at the top right to download as markdown or pdf.

Exploring the Visual Coding Dataset#

This notebook has had all its explanatory text removed and has not been run. It is intended to be downloaded and run locally (or on the provided binder) while listening to the presenter’s explanation. To see the fully rendered version of this notebook, go here

This notebook serves as a group project: in groups of 4 or 5, you will analyze data from the Visual Coding - Neuropixels dataset, published by the Allen Institute. This dataset uses extracellular electrophysiology probes to record spikes from multiple regions in the brain during passive visual stimulation.

To start, we will focus on the activity of neurons in primary visual cortex (VISp) during passive exposure to full-field flashes, either black (coded as “-1.0”) or white (coded as “1.0”), on a gray background. If you have time, you can apply the same procedure to other stimuli or brain areas.

For this exercise, you will:

  • Compute Peristimulus Time Histograms (PSTHs) and select relevant neurons to analyze using pynapple.

  • Fit GLMs to these neurons using nemos.

As this is the last notebook, the instructions are a bit more hands-off: you will make more of the analysis and modeling decisions yourselves. As a group, you will use your neuroscience knowledge and the skills gained over this workshop to decide:

  • How to select relevant neurons.

  • How to avoid overfitting.

  • What features to include in your GLMs.

  • Which basis functions (and parameters) to use for each feature.

  • How to regularize your features.

  • How to evaluate your model.

At the end of this session, we will regroup to discuss the decisions people made and evaluate each other’s models.

# Import everything
import jax
import matplotlib.pyplot as plt
import numpy as np
import pynapple as nap

import nemos as nmo

# some helper plotting functions
from nemos import _documentation_utils as doc_plots
import workshop_utils

import matplotlib as mpl
from matplotlib.ticker import MaxNLocator
from scipy.stats import gaussian_kde
from matplotlib.patches import Patch

# configure plots some
plt.style.use(nmo.styles.plot_style)

Downloading and preparing data#

In this section, we will download the data from DANDI and extract the relevant parts for analysis and modeling. This section is largely presented to you as is, so that you can get to the substantive sections more quickly.

First we download and load the data into pynapple.

# Dataset information
dandiset_id = "000021"
dandi_filepath = "sub-726298249/sub-726298249_ses-754829445.nwb"

# Download the data using NeMoS
io = nmo.fetch.download_dandi_data(dandiset_id, dandi_filepath)

# load data using pynapple
data = nap.NWBFile(io.read(), lazy_loading=True)

# grab the spiking data
units = data["units"]

# map from electrodes to brain area
channel_probes = {}
for elec in data.nwb.electrodes:
    channel_id = elec.index[0]
    location = elec["location"].values[0]
    channel_probes[channel_id] = location

# Add a new column to include location in our spikes TsGroup
units.brain_area = [channel_probes[int(ch_id)] for ch_id in units.peak_channel_id]

# drop unnecessary metadata
units.restrict_info(["rate", "quality", "brain_area"])

Now that we have our spiking data, let’s restrict our dataset to the relevant part.

Visual stimuli set

During the flash presentation trials, mice were exposed to white or black full-field flashes on a gray background, each lasting 250 ms and separated by a 2-second inter-trial interval. In total, they were exposed to 150 flashes (75 black, 75 white).

flashes = data["flashes_presentations"]
flashes.restrict_info(["color"])

# create a separate object for black and white flashes
flashes_white = flashes[flashes["color"] == "1.0"]
flashes_black = flashes[flashes["color"] == "-1.0"]

Let’s visualize our stimuli:


n_flashes = 5
n_seconds = 13
offset = .5

start = data["flashes_presentations"]["start"].min() - offset
end = start + n_seconds

fig, ax = plt.subplots(figsize=(17, 4))
for flash, c in zip([flashes_white, flashes_black], ["silver", "black"]):
    for fl in flash[:n_flashes]:
        ax.axvspan(fl.start[0], fl.end[0], color=c, alpha=.4, ec=c)      

plt.xlabel("Time (s)")
plt.ylabel("Absent = 0, Present = 1")
ax.set_title("Stimuli presentation")
ax.yaxis.set_major_locator(MaxNLocator(integer=True))

plt.xlim(start - .1, end)
    

Preliminary analyses and neuron selection#

From here on out, you will write the code yourself. In this first section, we will do some preliminary analyses to find the neurons that are most visually responsive; these are the neurons we will fit our GLMs to.

First, let’s construct an IntervalSet called extended_flashes which contains the peristimulus time. Right now, our flashes IntervalSet defines the start and end times of the flashes. To make sure we can model the pre-stimulus baseline and any responses to the stimulus turning off, we would like to expand these intervals to run from 500 ms before the start of each stimulus to 500 ms after its end.

This IntervalSet should be the same shape as flashes and have the same metadata columns.

dt = .5
extended_flashes =

If you have succeeded, the following should pass:

assert extended_flashes.shape == flashes.shape
assert all(extended_flashes.metadata == flashes.metadata)
assert all(extended_flashes.start == flashes.start - .5)
assert all(extended_flashes.end == flashes.end + .5)
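
For reference, here is one possible implementation (a sketch; it assumes your pynapple version lets the IntervalSet constructor take a metadata argument):

# One possible approach: shift the starts and ends by dt and carry the metadata over.
extended_flashes = nap.IntervalSet(
    start=flashes.start - dt,
    end=flashes.end + dt,
    metadata=flashes.metadata,
)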

Now, create two separate IntervalSet objects, extended_flashes_black and extended_flashes_white, which contain this info for only the black and the white flashes, respectively.

extended_flashes_white =
extended_flashes_black =
# This should all pass if you created the IntervalSet correctly
assert extended_flashes_white.shape == flashes_white.shape
assert all(extended_flashes_white.metadata == flashes_white.metadata)
assert all(extended_flashes_white.start == flashes_white.start - .5)
assert all(extended_flashes_white.end == flashes_white.end + .5)
assert extended_flashes_black.shape == flashes_black.shape
assert all(extended_flashes_black.metadata == flashes_black.metadata)
assert all(extended_flashes_black.start == flashes_black.start - .5)
assert all(extended_flashes_black.end == flashes_black.end + .5)
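
One way to do this (a sketch) is to filter on the color metadata column, mirroring how flashes_white and flashes_black were built above:

# A sketch mirroring the earlier color-based selection.
extended_flashes_white = extended_flashes[extended_flashes["color"] == "1.0"]
extended_flashes_black = extended_flashes[extended_flashes["color"] == "-1.0"]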

Now, select your neurons. There are four criteria we want to use:

  1. Brain area: we are interested in analyzing VISp units for this tutorial

  2. Quality: we will only select “good” quality units. If you’re curious, you can (optionally) read more about how the Allen Institute defines quality.

  3. Firing rate: overall, we want units with a firing rate greater than 2 Hz around the presentation of the stimuli.

  4. Responsiveness: we want units that actually respond to changes in the visual stimuli, i.e., their firing rate changes as a result of the stimulus.

Create a new TsGroup, selected_units, which includes only those units that meet the first three criteria, then check that it passes the assertion block.

Restrict!

Don’t forget when selecting based on firing rate that we want neurons whose firing rate is above the threshold around the presentation of the stimuli! This means you should use restrict()! If only we had a useful IntervalSet lying around…

selected_units =
assert len(selected_units) == 92
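
Here is a sketch of one way to combine the first three criteria (it assumes your pynapple version supports boolean metadata indexing on TsGroup; getby_category() is an alternative):

# A sketch: filter on brain area and quality, then restrict to the peristimulus
# epochs so that each unit's `rate` is recomputed around the stimuli only.
visp_good = units[(units.brain_area == "VISp") & (units.quality == "good")]
around_stim = visp_good.restrict(extended_flashes)
selected_units = around_stim[around_stim.rate > 2.0]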

Now, in order to determine the responsiveness of the units, it’s helpful to use the compute_perievent() function: this aligns each unit’s spike times to the onset of each stimulus repetition, so that we can average across trials.

Let’s use that function to construct two separate perievent dictionaries, one aligned to the start of the white stimuli and one aligned to the start of the black; each should run from 250 ms before to 500 ms after the event.

peri_white =
peri_black =
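
For reference, a sketch of this computation (check the exact compute_perievent() signature in your pynapple version):

# A sketch: align each unit's spikes to the flash onsets, from 250 ms before
# to 500 ms after each event.
peri_white = nap.compute_perievent(
    selected_units, nap.Ts(t=flashes_white.start), minmax=(-0.25, 0.5)
)
peri_black = nap.compute_perievent(
    selected_units, nap.Ts(t=flashes_black.start), minmax=(-0.25, 0.5)
)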

Visualizing these perievents can help us determine which units to include. The following helper function should help.


def plot_raster_psth(peri, units, color_flashes, n_units=9, start_unit=0, bin_size=.005, smoothing=0.015):
    """
    Plot perievent time histograms (PSTHs) and raster plots for multiple units.

    Parameters
    ----------
    peri : dict
        Dictionary mapping unit names to trial-aligned perievent spike data,
        as returned by compute_perievent().
    units : TsGroup
        The units to plot; their names index into `peri`.
    color_flashes : str
        A label indicating the flash color condition ('black' or 'white'), used for visual styling.
    n_units : int
        The number of units to visualize.
    start_unit : int
        The index of the unit to start with.
    bin_size : float
        Size of the bin used for spike count computation (in seconds).
    smoothing : float
        Standard deviation for Gaussian smoothing of the PSTH traces.
    """

    # Layout setup: one column per unit, 2 rows (PSTH on top, raster below)
    n_cols = n_units
    n_rows = 2
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(n_cols*2, 4))

    colors = plt.cm.tab10.colors

    # Extract unit names for iteration
    units_list = list(units.keys())[start_unit:start_unit+n_units]

    for i, unit in enumerate(units_list):
        u = peri[unit]
        line_color = colors[i % len(colors)]
        ax = axes[0, i]

        # Plot PSTH (smoothed firing rate)
        ax.plot(
            (np.mean(u.count(bin_size), 1) / bin_size).smooth(std=smoothing),
            linewidth=2,
            color=line_color
        )
        ax.axvline(0.0)  # Stimulus onset line

        span_color = "black" if color_flashes == "black" else "silver"
        ax.axvspan(0, 0.250, color=span_color, alpha=0.3, ec="black")  # Stim duration
        ax.set_xlim(-0.25, 0.50)
        ax.set_title(f'{unit}')

        # Plot raster
        ax = axes[1, i]
        ax.plot(u.to_tsd(), "|", markersize=1, color=line_color, mew=2)
        ax.axvline(0.0)
        ax.axvspan(0, 0.250, color=span_color, alpha=0.3, ec="black")
        ax.set_ylim(0, 75)
        ax.set_xlim(-0.25, 0.50)

    # Y-axis and title annotations
    axes[0, 0].set_ylabel("Rate (Hz)")
    axes[1, 0].set_ylabel("Trial")
    fig.text(0.5, 0.00, 'Time from stim (s)', ha='center')
    fig.text(0.5, 1.00, f'PSTH & Spike Raster Plot - {color_flashes} flashes', ha='center')
    plt.tight_layout()
# called like this, the function will visualize the first 9 units. play with the n_units
# and start_unit arguments to see the other units.
plot_raster_psth(peri_white, selected_units, "white", n_units=9, start_unit=0)
plot_raster_psth(peri_black, selected_units, "black", n_units=9, start_unit=0)

You could manually visualize each of our units and select those that appear, from their PSTHs, to be responsive.

However, it would be easier to scale (and more reproducible) if you came up with some measure of responsiveness. So how do we compute something that captures “this neuron responds to visual stimuli”?

You should be able to do this using a function that iterates over the peri_white and peri_black dictionaries, returning a single float for each unit.

Let’s aim to pick around 20 neurons.

If you’re having trouble coming up with one that seems reasonable, expand the following admonition.

# enter code here
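
For example, here is one illustrative measure (a sketch, not the only reasonable choice): the absolute change between the trial-averaged rate during the stimulus window and the pre-stimulus baseline, keeping the larger value across the two flash conditions.

# A sketch of a responsiveness measure: |stimulus rate - baseline rate|,
# computed from the trial-averaged PSTH.
def responsiveness(peri, bin_size=0.005):
    scores = {}
    for unit, u in peri.items():
        # trial-averaged firing rate as a function of peri-stimulus time
        rate = np.mean(u.count(bin_size), 1) / bin_size
        baseline = np.mean(rate.get(-0.25, 0.0))
        stim = np.mean(rate.get(0.0, 0.25))
        scores[unit] = np.abs(stim - baseline)
    return scores

resp_w = responsiveness(peri_white)
resp_b = responsiveness(peri_black)
resp = {k: max(resp_w[k], resp_b[k]) for k in resp_w}

# keep roughly the 20 most responsive units
top_units = sorted(resp, key=resp.get, reverse=True)[:20]
selected_units = selected_units[top_units]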

Let’s visualize the selected units’ PSTHs to make sure they all look reasonable:

print(f"Remaining units: {len(selected_units)}")
peri_white = {k: peri_white[k] for k in selected_units.index}
peri_black = {k: peri_black[k] for k in selected_units.index}

plot_raster_psth(peri_black, selected_units, "black", n_units=len(peri_black))
plot_raster_psth(peri_white, selected_units, "white", n_units=len(peri_white))

Avoiding overfitting#

As we’ve seen throughout this workshop, it is important to avoid overfitting your model. We’ve covered two strategies for doing so: either separate your dataset into train and test subsets or set up a cross-validation scheme. Pick one of these approaches and use it when fitting your GLM model in the next section.

You might find it helpful to refer back to the sklearn notebook and/or to use the following pynapple functions: set_diff(), union(), restrict().

# enter code here
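
For instance, here is a minimal trial-level train/test split (a sketch; it assumes IntervalSet supports boolean-array indexing in your pynapple version, otherwise rebuild the epochs with nap.IntervalSet):

# A minimal sketch: hold out every other trial for testing, so that both
# flash colors are represented in both sets.
trial_idx = np.arange(len(extended_flashes))
train_ep = extended_flashes[trial_idx % 2 == 0]
test_ep = extended_flashes[trial_idx % 2 == 1]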

Fit a GLM#

In this section, you will use nemos to build a GLM. There are a lot of scientific decisions to be made here, so we suggest starting simple and then adding complexity. Construct a design matrix with a single predictor, using a basis of your choice; then construct, fit, and score your model on a single neuron (remembering to use your train/test split or cross-validation to avoid overfitting). Then add regularization to your GLM. Then return to the beginning and add more predictors. Then fit all the neurons. Then evaluate which basis functions and parameters are best for your predictors. Finally, use the tricks we covered in the sklearn notebook to evaluate which predictors are necessary for your model and which are the most important.

You don’t have to exactly follow those steps, but make sure you can go from beginning to end before getting too complex.

Good luck and we look forward to seeing what you come up with!

Prepare data#

  • Create spike count data.

# enter code here
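
For example (a sketch; the bin size itself is a modeling decision):

# A sketch: 5 ms bins, restricted to the peristimulus epochs.
bin_size = 0.005
counts = selected_units.count(bin_size, ep=extended_flashes)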

Construct design matrix#

  • Decide on feature(s)

  • Decide on basis

  • Construct design matrix

# enter code here
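
As a starting point, here is a sketch with a single predictor: a 0/1 time series marking when a white flash is on, convolved with a log-spaced raised-cosine basis. The basis class name assumes a recent nemos release (older releases use RaisedCosineBasisLog with mode="conv"), and all variable names are illustrative.

# A sketch: one stimulus predictor convolved with a raised-cosine basis.
time = counts.t
on_white = np.zeros(len(time))
for s, e in zip(flashes_white.start, flashes_white.end):
    on_white[(time >= s) & (time < e)] = 1.0  # 1 while a white flash is on
stim_white = nap.Tsd(t=time, d=on_white, time_support=counts.time_support)

# ~500 ms of stimulus history at 5 ms bins -> window_size of 100 bins
basis = nmo.basis.RaisedCosineLogConv(
    n_basis_funcs=8, window_size=100, label="white_flash"
)
X = basis.compute_features(stim_white)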

Construct and fit your model#

  • Decide on regularization

  • Initialize GLM

  • Call fit

  • Visualize result on PSTHs

# enter code here
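
A sketch of a first pass on a single neuron, with ridge regularization (the solver and regularizer strength are illustrative starting points, not recommendations):

# A sketch: fit a ridge-regularized Poisson GLM to one neuron's counts.
unit_id = selected_units.index[0]  # illustrative choice of neuron
y = counts[:, 0]                   # that unit's spike-count column

model = nmo.glm.GLM(
    regularizer="Ridge",
    regularizer_strength=1e-3,
    solver_name="LBFGS",
)
model.fit(X.restrict(train_ep), y.restrict(train_ep))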

Here’s a helper function for plotting the PSTH of the data and predictions (for one or multiple neurons), which you may find helpful for visualizing your model performance.


def plot_pop_psth(
        peri,
        color_flashes,
        unit_id=None,
        bin_size=0.005,
        smoothing=0.015,
        **peri_others
        ):
    """Plot perievent time histograms (PSTHs) and raster plots for multiple units.

    Model predictions should be passed as additional keyword arguments. The key will be
    used as the label, and the value should be a 2-tuple of `(style, peri)`, where
    `style` is a matplotlib style (e.g., "blue" or "--") and `peri` is a PSTH
    dictionary, as returned by `compute_perievent_continuous`.
    
    Parameters
    ----------
    peri : dict or TsGroup
        Dictionary mapping unit names to trial-aligned perievent spike data
        (a single TsGroup is treated as one unit).
    color_flashes : str
        A label indicating the flash color condition ('black' or other), used for visual styling.
    bin_size : float
        Size of the bin used for spike count computation (in seconds).
    smoothing : float
        Standard deviation for Gaussian smoothing of the PSTH traces.
    peri_others : 2-tuples of (style, peri)
        Model PSTHs to plot, passed as keyword arguments. See above for details.

    """
    if not isinstance(peri, dict):
        peri = {0: peri}
        
    n_cols = len(peri)
    fig, axes = plt.subplots(1, n_cols,
                             figsize=(2.5 * n_cols, 2.5))
    if n_cols == 1:
        axes = [axes]

    for i, (unit, u) in enumerate(peri.items()):
        try:
            ax = axes[i]
        except TypeError:
            # then there's only one set of axes
            ax = axes
        # Plot PSTH (smoothed firing rate)
        ax.plot(
            (np.mean(u.count(bin_size), 1) / bin_size).smooth(std=smoothing),
            linewidth=2,
            color="black",
            label="Observed"
        )
        ax.axvline(0.0)  # Stimulus onset line
        span_color = "black" if color_flashes == "black" else "silver"
        ax.axvspan(0, 0.250, color=span_color, alpha=0.3, ec="black")  # Stim duration
        ax.set_xlim(-0.25, 0.50)
        ax.set_title(f'{unit}')
        for (key, (color, peri_pred)) in peri_others.items():
            try:
                p = peri_pred[:, :, i]
            except IndexError:
                p = peri_pred
            ax.plot(
                np.mean(p, axis=1),
                linewidth=1.5,
                color=color,
                label=key.capitalize(),
            )

    # Y-axis and title annotations
    axes[0].set_ylabel("Rate (Hz)")
    fig.legend(*ax.get_legend_handles_labels())
    fig.text(0.5, 0.00, 'Time from stim (s)', ha='center')
    fig.text(0.5, 1.00, f'PSTH - {color_flashes} flashes', ha='center')
    plt.tight_layout()

Score your model#

  • We trained on the train set, so now we score on the test set. (Or use cross-validation.)

  • Get a score for your model that you can use to compare across the modeling choices outlined above.

# enter code here
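
For example (a sketch; score_type is one of several metrics nemos supports, and pseudo-R² values are easier to compare across neurons than raw log-likelihood):

# A sketch: evaluate on the held-out trials only.
score = model.score(
    X.restrict(test_ep),
    y.restrict(test_ep),
    score_type="pseudo-r2-McFadden",
)
print(f"Test-set score: {score:.3f}")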

Try to improve your model?#

  • Go back to the beginning of this section and try to improve your model’s performance (as reflected by an increased score).

  • Keep track of what you’ve tried and their respective scores.

# Example construction of dataframe.
# In this:
# - additive_basis is the single AdditiveBasis object we used to construct the entire design matrix
# - model is the GLM we fit to a single neuron
# - unit_id is the int identifying the neuron we're fitting
# - score is the float giving the model score, summarizing model performance (on the test set)
import pandas as pd
# use a fresh name so we don't shadow `data`, which holds the NWB file
records = [
    {
        "model_id": 0,
        "regularizer": model.regularizer.__class__.__name__,
        "regularizer_strength": model.regularizer_strength,
        "solver": model.solver_name,
        "score": score,
        "n_predictors": len(additive_basis),
        "unit": unit_id,
        "predictor_i": i,
        "predictor": basis.label.strip(),
        "basis": basis.__class__.__name__,
        # any other info you think is important ...
    }
    for i, basis in enumerate(additive_basis)
]
df = pd.DataFrame(records)

df