Selection of a timestep for SNN simulation

What is the proper timestep to select when simulating a spiking neural network? The answer is, of course: it depends. However, I think the usual assumption, that leaky integrate-and-fire (LIF) neurons need a small timestep to be simulated correctly, is incorrect. Here's why!


Context

I recently came across a tweet by Dan Goodman presenting a brief experiment that demonstrates the detrimental effects of using a large timestep ($\delta t$) when simulating a LIF neuron. The output spiking rate of a LIF neuron driven by Poisson spike input was found to decrease as the timestep increased, with failure observed as soon as $\delta t = 1$ ms, a standard timestep size within the CS-oriented community.

Of course, there is a direct relationship between the choice of $\delta t$ and the real-world simulation duration (wall-clock time): the larger the timestep, the fewer steps there are to compute. Ideally, we would all be using a very large $\delta t$ for our simulations. As Guillaume Bellec pointed out, there might not even be any advantage to using a small $\delta t$ in a machine learning setting.

Rather than accepting the necessity of a small timestep, it is worth investigating why the simulation fails, even when employing an exact solver for the leak instead of Euler's method. With an exact solver, the only remaining discretization error comes from where in the clock cycle an input spike arrives: a spike of weight $w$ arriving at the very beginning of a step has decayed to $w\exp(-\delta t/\tau)$ by the end of it, whereas one arriving at the very end still contributes $w$. Depending on when the spike arrived during the clock period, we therefore slightly over- or underestimate the membrane potential, by at most $w\,(1 - \exp(-\delta t/\tau))$.
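A spike of weight $w$ arriving at the very start of a step contributes $w\exp(-\delta t/\tau)$ by the end of that step, while one arriving at the very end contributes $w$. To put a number on the gap between these two extremes, here is a small sketch (not part of the original experiment) that computes $w\,(1 - \exp(-\delta t/\tau))$ for a few timesteps. It assumes $w = 1$ and the same $\tau = 10$ ms as the experiment below; `max_step_error` is a hypothetical helper name.

```python
import numpy as np

def max_step_error(dt, tau=0.010, w=1.0):
    """Gap between the end-of-step contribution of a spike arriving at the very
    end of a step (w) and one arriving at the very start (w * exp(-dt / tau))."""
    return w * (1.0 - np.exp(-dt / tau))

for dt in [1e-4, 1e-3, 1e-2]:  # timesteps in seconds
    err = max_step_error(dt)
    print(f"dt = {dt * 1e3:g} ms: max error = {err:.4f} (relative to w = 1)")
```

For $\tau = 10$ ms the gap is about 1% of $w$ at $\delta t = 0.1$ ms, about 9.5% at 1 ms and about 63% at 10 ms, so the leak-related error does grow with $\delta t$ but stays bounded by the weight of each input spike.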

Recreating the Simulation

A straightforward experiment can replicate the behavior described in the tweet. We simulate 1,000 LIF neurons, each stimulated by the same 100 Poisson spike trains firing at 5 Hz, for 4 seconds. The LIF time constant is $\tau = 10$ ms. The weights between the 100 inputs and the 1,000 output neurons are randomly sampled from a normal distribution $\mathcal{N}(0.1, 0.25)$ (mean 0.1, variance 0.25, i.e. standard deviation 0.5). We then compute the firing rate of every output neuron and plot its mean across neurons, with the standard deviation as error bars.

import numpy as np
import plotly.graph_objects as go

np.random.seed(0x1B)
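duration = 4  # seconds
tau = 0.010  # membrane time constant, seconds (10 ms)
thresh = 1  # firing threshold
nb_inputs = 100
nb_outputs = 1000
input_rate = 5  # Hz
weights = np.random.randn(nb_outputs, nb_inputs) * 0.5 + 0.1  # N(0.1, 0.5**2)
dts = np.logspace(-5, -1.5, 10)  # in seconds, from 0.01 ms to ~31.6 ms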

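spike_rates = np.zeros((len(dts), nb_outputs))
for i, dt in enumerate(dts):
    time = np.arange(0, duration, dt)
    u = np.zeros(nb_outputs)  # membrane potentials
    _exp = np.exp(-dt / tau)  # exact leak factor over one timestep
    # Poisson input: each input emits on average input_rate * dt spikes per step
    input_spikes = np.random.poisson(lam=input_rate * dt, size=(len(time), nb_inputs))
    weighted_input_spikes = input_spikes @ weights.T  # shape (len(time), nb_outputs)
    spike_count = np.zeros(nb_outputs)

    for j in range(len(time)):
        u = _exp * u + weighted_input_spikes[j]  # leak, then add this step's input
        spikes = u > thresh  # at most one spike per neuron per timestep
        spike_count += spikes
        u[spikes] = 0  # hard reset
    spike_rates[i] += spike_count / duration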

fig = go.Figure(go.Scatter(
    x=dts * 1000, y=spike_rates.mean(axis=1),
    error_y=dict(type='data', array=spike_rates.std(axis=1), visible=True),
    mode='lines+markers',
))
fig.update_layout(
    xaxis=dict(title='δt [ms]', type='log'),
    yaxis=dict(title='Output firing rate [sp/s]'),
)
fig.show()
Output firing rate vs. δt — the spiking frequency drops near δt = 1 ms.

We arrive at a similar-looking plot, in which the output firing rate drops as $\delta t$ approaches 1 ms.

Hypothesis

Numerous commenters in the original thread suggested that $\delta t$ should be chosen to match $\tau$. The chosen time constant certainly has some influence: the less the membrane leaks during a timestep, the smaller the error that a misplaced input spike can cause. However, I am skeptical of this explanation because of the stochastic nature of Poisson spikes. Since a spike can occur at any time during a timestep, it seems likely that overestimates of the membrane potential will roughly cancel out underestimates.

My hypothesis is different: I contend that the real issue lies elsewhere. Because of the way the simulation is discretized, a neuron can emit at most one spike per timestep, and the hard reset to zero discards any potential accumulated beyond the threshold during that step. Consequently, the LIF neuron enters a sort of implicit refractory period lasting one timestep, and its firing rate can never exceed $1/\delta t$. When the timestep is very large (greater than 1 ms in this instance), this implicit refractory period becomes long, and the neuron loses important input spikes because it is unable to integrate new input during this interval.
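One direct consequence of a single spike per timestep is a hard ceiling of $1/\delta t$ on the firing rate. The sketch below checks this by driving a single neuron with a hypothetical, very large input (`drive`) at every step, using the same exact-leak update and hard reset as the simulation code in this post; `max_rate_hard_reset` and its default values are illustrative assumptions, not part of the original experiment.

```python
import numpy as np

def max_rate_hard_reset(dt, duration=1.0, tau=0.010, thresh=1.0, drive=100.0):
    """Firing rate (sp/s) of one LIF neuron receiving a hypothetical, very large
    input `drive` at every step, simulated with a hard reset."""
    n_steps = int(round(duration / dt))
    decay = np.exp(-dt / tau)  # exact leak factor over one timestep
    u, count = 0.0, 0
    for _ in range(n_steps):
        u = decay * u + drive  # leak, then add this step's input
        if u > thresh:         # at most one spike per timestep
            count += 1
            u = 0.0            # hard reset discards the excess potential
    return count / duration

for dt in [1e-4, 1e-3, 1e-2]:  # timesteps in seconds
    print(f"dt = {dt * 1e3:g} ms: {max_rate_hard_reset(dt):.0f} sp/s, 1/dt = {1 / dt:.0f}")
```

However strong the input, the neuron fires exactly once per step, so its rate equals $1/\delta t$, e.g. only 100 sp/s at $\delta t = 10$ ms.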

If this hypothesis is correct, i.e. if the timestep $\delta t$ forces an implicit refractory period, then simulating with a small $\delta t$ plus an explicit refractory period should yield the same result as simulating with a larger $\delta t$ of the same length. If we add a refractory period to the experiment above, we see that they do indeed have a similar effect:

fig = go.Figure()

for refractory_period in [0.001, 0.01, 0.1]:
    spike_rates = np.zeros((len(dts), nb_outputs))
    for i, dt in enumerate(dts):
        time = np.arange(0, duration, dt)
        refrac_clk = int(refractory_period / dt)  # refractory period in timesteps (0 if dt is longer)
        u = np.zeros(nb_outputs)
        refrac_cntr = np.zeros(nb_outputs, dtype=int)  # remaining refractory steps per neuron
        _exp = np.exp(-dt / tau)
        input_spikes = np.random.poisson(lam=input_rate * dt, size=(len(time), nb_inputs))
        weighted_input_spikes = input_spikes @ weights.T
        spike_count = np.zeros(nb_outputs)

        for j in range(len(time)):
            non_refrac = refrac_cntr == 0
            # only non-refractory neurons leak and integrate input; refractory ones stay at 0
            u[non_refrac] = _exp * u[non_refrac] + weighted_input_spikes[j, non_refrac]
            spikes = u > thresh
            spike_count += spikes
            u[spikes] = 0  # hard reset
            refrac_cntr = np.maximum(refrac_cntr - 1, 0)
            refrac_cntr[spikes] += refrac_clk  # start the refractory period

        spike_rates[i] += spike_count / duration

    fig.add_trace(go.Scatter(
        x=dts * 1000, y=spike_rates.mean(axis=1),
        error_y=dict(type='data', array=spike_rates.std(axis=1), visible=True),
        mode='lines+markers',
        name=f"Refrac.: {1000 * refractory_period:.1f}ms",
    ))

fig.update_layout(
    xaxis=dict(title='δt [ms]', type='log'),
    yaxis=dict(title='Output firing rate [sp/s]'),
)
fig.show()
Adding explicit refractory periods produces the same degradation pattern as large δt.

As we see, the output firing rates align when $\delta t$ equals the refractory period. Therefore, the model is actually correct; the only subtlety is that the effective refractory period is the maximum of $\delta t$ and the explicit refractory period.

Solution

The solution is quite simple. The timestep forces an implicit refractory period because the neuron can only spike once per timestep. If we remove this limitation, the implicit refractory period disappears and the output firing rate should be constant regardless of the timestep.

To do so, we divide the membrane potential $u(t)$ by the threshold and round down, which estimates how many times the neuron would have spiked within one timestep:

$$n_{\text{spikes}}(t) = \left\lfloor \frac{\max\{u(t),\, 0\}}{u_{\text{thresh}}} \right\rfloor$$

We also change the reset so that the threshold is subtracted $n_{\text{spikes}}(t)$ times, $u(t) \leftarrow u(t) - n_{\text{spikes}}(t)\,u_{\text{thresh}}$: a soft reset. This is more precise when dealing with large timesteps, as membrane potential accumulated beyond the threshold is not wasted by an early spike within the timestep.
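As a minimal sketch of the spike-count formula and the soft reset applied to a single timestep, the hypothetical helper `multi_spike_soft_reset` below takes an array of membrane potentials and returns the spike counts and the potentials after the reset, assuming $u_{\text{thresh}} = 1$ by default.

```python
import numpy as np

def multi_spike_soft_reset(u, thresh=1.0):
    """Emit floor(max(u, 0) / thresh) spikes per neuron, then subtract
    thresh once per emitted spike (soft reset)."""
    n_spikes = np.floor(np.maximum(u, 0.0) / thresh)
    return n_spikes, u - n_spikes * thresh

u = np.array([2.7, 0.4, -0.3])  # membrane potentials of three neurons
n_spikes, u_after = multi_spike_soft_reset(u)
print(n_spikes)  # spike counts per neuron
print(u_after)   # residual membrane potentials
```

Here the first neuron emits two spikes and keeps a residual potential of about 0.7, which a hard reset would have discarded; the other two stay below threshold and are left untouched.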

np.random.seed(0x1B)
spike_rates = np.zeros((len(dts), nb_outputs))
for i, dt in enumerate(dts):
    time = np.arange(0, duration, dt)
    u = np.zeros(nb_outputs)
    _exp = np.exp(-dt / tau)
    input_spikes = np.random.poisson(lam=input_rate * dt, size=(len(time), nb_inputs))
    weighted_input_spikes = input_spikes @ weights.T
    spike_count = 0

    for j, t in enumerate(time):
        u = _exp * u + weighted_input_spikes[j]
        spikes = np.floor(np.maximum(u, 0) / thresh)  # multiple spikes per timestep
        spike_count += spikes
        u -= spikes * thresh  # soft reset

    spike_rates[i] += spike_count / duration

fig = go.Figure(go.Scatter(
    x=dts * 1000, y=spike_rates.mean(axis=1),
    error_y=dict(type='data', array=spike_rates.std(axis=1), visible=True),
    mode='lines+markers',
))
fig.update_layout(
    xaxis=dict(title='δt [ms]', type='log'),
    yaxis=dict(title='Output firing rate [sp/s]'),
)
fig.show()
With soft reset and multiple spikes per timestep, the firing rate is stable across all δt values.

And voilà! We get the expected firing rate across all timesteps. While this solution is very interesting for computational neuroscientists, it partly gives up the energy efficiency of spiking neural networks: spikes are no longer binary events but integer counts, and the soft reset involves some arithmetic (a multiplication and a subtraction) instead of simply zeroing the potential.