Spike Sorting Tutorial

This tutorial will guide you through creating spike sorting visualizations using figpack’s specialized spike sorting views.

Prerequisites

Before starting this tutorial, make sure you have the required packages installed:

pip install figpack figpack_spike_sorting spikeinterface

Units Table

The Units Table displays basic information about detected neural units:

import spikeinterface.extractors as se
import figpack_spike_sorting.views as ssv

# Generate synthetic spike data
recording, sorting = se.toy_example(
    num_units=6, duration=60, seed=0, num_segments=1
)

# Define the table columns: unit ID and spike count
columns = [
    ssv.UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
    ssv.UnitsTableColumn(key="numSpikes", label="Spike Count", dtype="int"),
]

# Build one row per unit. Pass segment_index=0 explicitly, matching the
# other examples in this tutorial (the data has a single segment).
rows = []
for unit_id in sorting.get_unit_ids():
    spike_count = len(sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id))
    rows.append(
        ssv.UnitsTableRow(
            unit_id=unit_id,
            values={
                "unitId": unit_id,
                "numSpikes": spike_count,
            },
        )
    )

view = ssv.UnitsTable(columns=columns, rows=rows)
view.show(title="Units Table Example", open_in_browser=True)

Unit Metrics Graph

The Unit Metrics Graph provides interactive visualization of unit metrics, allowing you to analyze relationships between different properties of units:

import spikeinterface.extractors as se
import figpack_spike_sorting.views as ssv

# Generate synthetic spike data
recording, sorting = se.toy_example(
    num_units=12, duration=300, seed=0, num_segments=1
)

# Declare which per-unit metrics the graph should expose
metrics = [
    ssv.UnitMetricsGraphMetric(key="numEvents", label="Num. events", dtype="int"),
    ssv.UnitMetricsGraphMetric(key="firingRateHz", label="Firing rate (Hz)", dtype="float"),
]

# Recording duration in seconds — the same for every unit
duration_sec = recording.get_num_frames(segment_index=0) / recording.get_sampling_frequency()

# Compute the metric values for each unit
units = []
for unit_id in sorting.get_unit_ids():
    n_spikes = len(sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id))
    units.append(
        ssv.UnitMetricsGraphUnit(
            unit_id=unit_id,
            values={
                "numEvents": n_spikes,
                "firingRateHz": n_spikes / duration_sec,
            },
        )
    )

# Create and show the view
view = ssv.UnitMetricsGraph(units=units, metrics=metrics, height=500)
view.show(title="Unit Metrics Graph Example", open_in_browser=True)

This creates an interactive view where you can:

  • View histograms of individual metrics

  • Compare metrics in scatter plots

  • Select which metrics to display

  • Adjust histogram bin sizes

  • Zoom into regions of interest

  • Select and highlight units

Raster Plot

A raster plot shows when each unit fired spikes over time:

import numpy as np
import spikeinterface.extractors as se
import figpack_spike_sorting.views as ssv

# Generate synthetic data
recording, sorting = se.toy_example(
    num_units=8, duration=30, seed=0, num_segments=1
)

# Convert each unit's spike train from sample frames to seconds and wrap
# it in a RasterPlotItem
fs = sorting.get_sampling_frequency()
plot_items = [
    ssv.RasterPlotItem(
        unit_id=unit_id,
        spike_times_sec=(
            np.array(sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id)) / fs
        ).astype(np.float32),
    )
    for unit_id in sorting.get_unit_ids()
]

view = ssv.RasterPlot(
    start_time_sec=0,
    end_time_sec=30,
    plots=plot_items,
)

view.show(title="Raster Plot Example", open_in_browser=True)

The same approach works with real data from the DANDI Archive (see the NWB loading example in the Spike Amplitudes section below).

Spike Amplitudes

Visualize spike amplitudes over time to assess unit stability:

import numpy as np
import spikeinterface.extractors as se
import figpack_spike_sorting.views as ssv

# Generate synthetic data
recording, sorting = se.toy_example(
    num_units=5, duration=60, seed=0, num_segments=1
)

# Build one SpikeAmplitudesItem per unit, pairing each spike time with a
# simulated amplitude value
plot_items = []
rng = np.random.default_rng(42)
fs = sorting.get_sampling_frequency()

for unit_id in sorting.get_unit_ids():
    spike_train = sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id)
    spike_times_sec = np.array(spike_train) / fs

    # Simulate realistic amplitudes: a per-unit baseline plus Gaussian noise
    base_amplitude = rng.uniform(50, 200)
    amplitudes = base_amplitude + rng.normal(0, 10, len(spike_times_sec))

    plot_items.append(
        ssv.SpikeAmplitudesItem(
            unit_id=unit_id,
            spike_times_sec=spike_times_sec.astype(np.float32),
            spike_amplitudes=amplitudes.astype(np.float32),
        )
    )

view = ssv.SpikeAmplitudes(
    start_time_sec=0,
    end_time_sec=60,
    plots=plot_items,
)

view.show(title="Spike Amplitudes Example", open_in_browser=True)

Alternatively, you can load spike amplitudes directly from a units table in a local or remote NWB file:

import figpack_spike_sorting.views as ssv

# Load spike amplitudes directly from a units table in a remote NWB file
# (a local file path works the same way).
print("Loading from remote NWB file...")  # plain string: no placeholders, so no f-string needed
view = ssv.SpikeAmplitudes.from_nwb_units_table(
    "https://api.dandiarchive.org/api/assets/37ca1798-b14c-4224-b8f0-037e27725336/download/",
    units_path="/units",
    include_units_selector=True,
)
view.show(title="NWB Spike Amplitudes Example")

Average Waveforms

Average waveforms show the average spike shape for each unit:

import spikeinterface as si
import figpack_spike_sorting.views as ssv

# Generate synthetic data with ground truth (the extractors module is not
# needed here, so it is not imported)
recording, sorting = si.generate_ground_truth_recording(
    durations=[120],
    num_units=8,
    seed=0,
    num_channels=6,
    noise_kwargs={"noise_levels": 50},
)

# Create sorting analyzer to compute waveforms
sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording)

# Create average waveforms view
view = ssv.AverageWaveforms.from_sorting_analyzer(sorting_analyzer)
view.show(title="Average Waveforms Example", open_in_browser=True)

This visualization allows you to:

  • View the average spike waveform for each unit

  • See waveforms across all recording channels

  • Assess the spatial extent and amplitude of each unit

  • Identify the primary channel for each unit

  • Compare waveform shapes between units

Autocorrelograms

Autocorrelograms help assess the quality of unit isolation by showing refractory periods:

import spikeinterface.extractors as se
import figpack_spike_sorting.views as ssv

# Synthetic recording/sorting pair for the demonstration
recording, sorting = se.toy_example(
    num_units=6, duration=120, seed=0, num_segments=1
)

# The from_sorting helper computes the autocorrelograms for us
view = ssv.Autocorrelograms.from_sorting(sorting)
view.show(title="Autocorrelograms Example", open_in_browser=True)

Unit Locations

The Unit Locations view shows the spatial arrangement of units and recording channels on the probe.

In this example, we generate synthetic data and simulate unit locations for demonstration:

from typing import List
import numpy as np
import spikeinterface as si
import figpack_spike_sorting.views as ssv

# Generate synthetic data with ground truth
recording, sorting = si.generate_ground_truth_recording(
    durations=[120],
    num_units=10,
    seed=0,
    num_channels=8,
    noise_kwargs={"noise_levels": 50},
)

# Determine the spatial extent of the probe from the channel positions,
# guarding against degenerate (zero-width) ranges
channel_locations = recording.get_channel_locations()
xmin = np.min(channel_locations[:, 0])
xmax = np.max(channel_locations[:, 0])
if xmax <= xmin:
    xmax = xmin + 1
ymin = np.min(channel_locations[:, 1])
ymax = np.max(channel_locations[:, 1])
if ymax <= ymin:
    ymax = ymin + 1

# Simulated unit locations spread evenly across the probe extent
unit_ids = sorting.get_unit_ids()
unit_items: List[ssv.UnitLocationsItem] = []
for ii, unit_id in enumerate(unit_ids):
    unit_items.append(
        ssv.UnitLocationsItem(
            unit_id=unit_id,
            x=float(xmin + ((ii + 0.5) / len(unit_ids)) * (xmax - xmin)),
            y=float(ymin + ((ii + 0.5) / len(unit_ids)) * (ymax - ymin)),  # simulated location
        )
    )

# Map channel ids to their probe coordinates. Reuse the channel_locations
# array computed above instead of re-querying the recording on every
# iteration; get_channel_ids() matches the other examples in this tutorial.
channel_locations_dict = {}
for ii, channel_id in enumerate(recording.get_channel_ids()):
    channel_locations_dict[str(channel_id)] = channel_locations[ii, :].astype(np.float32)

view = ssv.UnitLocations(
    units=unit_items,
    channel_locations=channel_locations_dict,
    disable_auto_rotate=True
)

view.show(title="Unit Locations Example", open_in_browser=True)

Cross Correlograms

Cross correlograms reveal temporal relationships between different units, helping identify potential synchrony or interactions:

import spikeinterface.extractors as se
import figpack_spike_sorting.views as ssv

# A larger synthetic data set so there are many unit pairs to compare
recording, sorting = se.toy_example(
    num_units=9, duration=300, seed=0, num_segments=1
)

# Build the cross correlograms view directly from the sorting object
view = ssv.CrossCorrelograms.from_sorting(sorting)
view.show(title="Cross Correlograms Example", open_in_browser=True)

Unit Similarity Matrix

The Unit Similarity Matrix displays pairwise similarity scores between units, helping identify potentially redundant units or assess the overall quality of unit separation:

from typing import List
import spikeinterface.extractors as se
import figpack_spike_sorting.views as ssv

# Generate synthetic data
recording, sorting = se.toy_example(
    num_units=12, duration=300, seed=0, num_segments=1
)

unit_ids = list(sorting.get_unit_ids())

# Build a fake similarity score for every ordered pair of units.
# In real use these would come from waveform correlations, template
# matching, or a similar comparison; here the indices provide a simple
# numeric stand-in for demonstration.
similarity_scores: List[ssv.UnitSimilarityScore] = [
    ssv.UnitSimilarityScore(
        unit_id1=u1,
        unit_id2=u2,
        similarity=1 - abs(i - j) / (i + j + 1),
    )
    for i, u1 in enumerate(unit_ids)
    for j, u2 in enumerate(unit_ids)
]

# The optional range parameter fixes the min/max of the color scale
view = ssv.UnitSimilarityMatrix(
    unit_ids=unit_ids,
    similarity_scores=similarity_scores,
    range=(0, 1)
)

view.show(title="Unit Similarity Matrix Example", open_in_browser=True)

This visualization allows you to:

  • View similarity scores in a heatmap format

  • Identify highly similar units that may need merging

  • Select units to highlight their similarity relationships

  • Use the “Select Similar” button to quickly select the most similar units to the current selection

Spike Locations

The Spike Locations view displays the spatial distribution of individual spike events in 2D space, useful for visualizing spike localization from multi-electrode arrays:

from typing import List
import numpy as np
import spikeinterface.extractors as se
from figpack_spike_sorting.views import SpikeLocations, SpikeLocationsItem

# Generate synthetic spike data
recording, sorting = se.toy_example(
    num_units=12, num_channels=10, duration=300, seed=0
)

# Get channel locations for the probe
channel_locations = recording.get_channel_locations().astype(np.float32)
xmin = np.min(channel_locations[:, 0])
xmax = np.max(channel_locations[:, 0])
ymin = np.min(channel_locations[:, 1])
ymax = np.max(channel_locations[:, 1])

# Guard against degenerate (zero-width) ranges before expanding
if xmax <= xmin:
    xmin = xmin - 12
    xmax = xmax + 12
if ymax <= ymin:
    ymin = ymin - 12
    ymax = ymax + 12

# Expand the ranges by 20% on each side for better visualization
xspan = xmax - xmin
yspan = ymax - ymin
xmin = xmin - xspan * 0.2
xmax = xmax + xspan * 0.2
ymin = ymin - yspan * 0.2
ymax = ymax + yspan * 0.2

# Create spike location items for each unit
rng = np.random.default_rng(2022)
items: List[SpikeLocationsItem] = []
for unit_id in sorting.get_unit_ids():
    spike_times_sec = (
        np.array(sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id))
        / sorting.get_sampling_frequency()
    )

    # Simulate a center location for this unit
    center_x = rng.uniform(xmin, xmax)
    center_y = rng.uniform(ymin, ymax)

    # Create simulated spike locations scattered around the center
    items.append(
        SpikeLocationsItem(
            unit_id=unit_id,
            spike_times_sec=spike_times_sec.astype(np.float32),
            x_locations=rng.normal(center_x, 6, spike_times_sec.shape).astype(
                np.float32
            ),
            y_locations=rng.normal(center_y, 6, spike_times_sec.shape).astype(
                np.float32
            ),
        )
    )

# Map channel ids to coordinates, reusing the float32 array computed above
# instead of re-querying the recording on every iteration
channel_locations_dict = {}
for ii, channel_id in enumerate(recording.get_channel_ids()):
    channel_locations_dict[str(channel_id)] = channel_locations[ii, :]

# Create the view
view = SpikeLocations(
    units=items,
    x_range=(float(xmin), float(xmax)),
    y_range=(float(ymin), float(ymax)),
    channel_locations=channel_locations_dict,
    disable_auto_rotate=True,
)

view.show(title="Spike Locations Example", open_in_browser=True)

This visualization allows you to:

  • View the 2D spatial distribution of spike events

  • Visualize how spikes cluster in space for each unit

  • See channel locations overlaid on the spike scatter plot

  • Select and highlight specific units

  • Use the “Only show selected” option to focus on particular units

Simple Combined Layout

Combine multiple views:

import numpy as np
import spikeinterface.extractors as se
import figpack_spike_sorting.views as ssv
import figpack.views as vv

# Generate synthetic data
recording, sorting = se.toy_example(
    num_units=6, duration=60, seed=0, num_segments=1
)

# Create units table
columns = [
    ssv.UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
    ssv.UnitsTableColumn(key="numSpikes", label="Spikes", dtype="int"),
]

# Pass segment_index=0 explicitly, matching the raster-plot loop below
rows = []
for unit_id in sorting.get_unit_ids():
    spike_count = len(sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id))
    rows.append(
        ssv.UnitsTableRow(
            unit_id=unit_id,
            values={"unitId": unit_id, "numSpikes": spike_count},
        )
    )

units_table = ssv.UnitsTable(columns=columns, rows=rows)

# Create raster plot (spike trains converted from frames to seconds)
plot_items = []
for unit_id in sorting.get_unit_ids():
    spike_times_sec = (
        np.array(sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id))
        / sorting.get_sampling_frequency()
    )
    plot_items.append(
        ssv.RasterPlotItem(
            unit_id=unit_id,
            spike_times_sec=spike_times_sec.astype(np.float32)
        )
    )

raster_plot = ssv.RasterPlot(
    start_time_sec=0,
    end_time_sec=60,
    plots=plot_items,
)

# Combine in a layout: units table on the left, raster plot on the right
view = vv.Splitter(
    direction="horizontal",
    item1=vv.LayoutItem(view=units_table, max_size=300, title="Units"),
    item2=vv.LayoutItem(view=raster_plot, title="Spike Times"),
    split_pos=0.3,
)

view.show(title="Spike Sorting Dashboard", open_in_browser=True)

Sorting Curation

The Sorting Curation view provides an interactive interface for manually curating spike sorting results. This allows you to label units as “good”, “noise”, “mua” (multi-unit activity), or other custom labels to assess the quality of the sorting:

from typing import List
import spikeinterface.extractors as se
import figpack_spike_sorting.views as ssv
import figpack.views as vv

# Generate synthetic data
recording, sorting = se.toy_example(
    num_units=18, duration=300, seed=0, num_segments=1
)

# Units table for the left side: a single column listing the unit ids
columns: List[ssv.UnitsTableColumn] = [
    ssv.UnitsTableColumn(key="unitId", label="Unit", dtype="int"),
]
rows: List[ssv.UnitsTableRow] = [
    ssv.UnitsTableRow(
        unit_id=unit_id,
        values={
            "unitId": unit_id,
        },
    )
    for unit_id in sorting.get_unit_ids()
]

units_table = ssv.UnitsTable(
    columns=columns,
    rows=rows,
)

# Autocorrelograms for the right side
autocorrelograms = ssv.Autocorrelograms.from_sorting(sorting)

# Interactive curation panel offering the standard label choices
curation_view = ssv.SortingCuration(default_label_options=["mua", "good", "noise"])

# Stack the units table above the curation panel
left_panel = vv.Box(
    direction="vertical",
    items=[
        vv.LayoutItem(view=units_table, title="Units Table"),
        vv.LayoutItem(view=curation_view, title="Sorting Curation"),
    ],
)

# Place the left panel (table + curation) beside the autocorrelograms
view = vv.Splitter(
    direction="horizontal",
    item1=vv.LayoutItem(view=left_panel, max_size=800, title="Units"),
    item2=vv.LayoutItem(view=autocorrelograms, title="Autocorrelograms"),
    split_pos=0.25,  # 25% for the left panel, 75% for autocorrelograms
)

view.show(title="Sorting Curation Example", open_in_browser=True)