Selection of a timestep for SNN simulation
What is the proper timestep to select when simulating a spiking neural network? The answer is, of course, that it depends. However, I think the usual assumption is incorrect when using leaky integrate-and-fire (LIF) neurons. Here's why!
Context
I recently came across a tweet by Dan Goodman presenting a brief experiment that demonstrates the detrimental effect of using a large timestep (δt) when simulating a LIF neuron. The output firing rate of a LIF neuron driven by Poisson spike input decreased as the timestep increased, with failure observed as soon as δt reached about 1 ms, a standard timestep size within the CS-oriented community.
Of course, there is a direct relationship between the choice of δt and the real-world simulation duration (or wall-clock time). Ideally, we would all be using a very large δt for our simulations. As Guillaume Bellec pointed out, there might not even be any advantage to using a small δt in a machine learning setting.
Rather than accepting the necessity of a small timestep, it is worth investigating why the simulation fails, even when employing an exact solver instead of Euler's method. With an exact solver, whether a spike arrives at the beginning or at the end of a clock cycle should make only a small difference: we slightly over- or underestimate the membrane potential depending on when the spike arrived during the clock period.
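To get a feel for how small that distinction is, here is a minimal sketch that bounds the contribution of a single unit-weight spike at the end of a timestep. It assumes a 10 ms time constant (the value used in the code further down); the helper name `timing_error_bounds` is hypothetical and only serves this illustration.

```python
import numpy as np

tau = 0.010  # membrane time constant in seconds (assumed: 10 ms)

def timing_error_bounds(dt, tau=tau):
    """Range of a unit spike's contribution at the end of a step of length dt.

    A spike arriving s seconds before the end of the step has decayed to
    exp(-s / tau) by the end of the step, with 0 <= s <= dt.
    """
    earliest = np.exp(-dt / tau)  # spike at the very start of the step (full decay)
    latest = 1.0                  # spike at the very end of the step (no decay)
    return earliest, latest

for dt in (1e-4, 1e-3, 1e-2):
    lo, hi = timing_error_bounds(dt)
    print(f"dt = {dt * 1e3:5.1f} ms: contribution between {lo:.3f} and {hi:.3f}")
```

Under these assumptions, a 1 ms timestep misplaces at most about 10% of a single spike's contribution, which on its own does not look large enough to explain a collapse of the firing rate.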
Recreating the Simulation
A straightforward experiment can be devised to replicate the behavior outlined in the tweet. We will simulate 1000 LIF neurons stimulated by 100 Poisson spike trains firing at 5 Hz for 4 seconds. The LIF's time constant is τ = 10 ms. The weights between the 100 inputs and the 1000 output neurons are randomly sampled from a normal distribution with mean 0.1 and standard deviation 0.5. We then compute the mean output firing rate across the output neurons, with the corresponding standard deviation as error bars.
import numpy as np
import plotly.graph_objects as go

np.random.seed(0x1B)

duration = 4  # seconds
tau = 0.010  # membrane time constant, in seconds
thresh = 1  # firing threshold
nb_inputs = 100
nb_outputs = 1000
input_rate = 5  # Hz
weights = np.random.randn(nb_outputs, nb_inputs) * 0.5 + 0.1  # mean 0.1, std 0.5
dts = np.logspace(-5, -1.5, 10)  # timesteps, in seconds

spike_rates = np.zeros((len(dts), nb_outputs))
for i, dt in enumerate(dts):
    time = np.arange(0, duration, dt)
    u = np.zeros(nb_outputs)  # membrane potentials
    _exp = np.exp(-dt / tau)  # exact leak over one timestep
    # Input spike counts per timestep and per input (can exceed 1 for large dt)
    input_spikes = np.random.poisson(lam=input_rate * dt, size=(len(time), nb_inputs))
    weighted_input_spikes = input_spikes @ weights.T
    spike_count = np.zeros(nb_outputs)
    for j in range(len(time)):
        u = _exp * u + weighted_input_spikes[j]
        spikes = u > thresh  # at most one spike per neuron per timestep
        spike_count += spikes
        u[spikes] = 0  # hard reset
    spike_rates[i] += spike_count / duration

fig = go.Figure(go.Scatter(
    x=dts * 1000, y=spike_rates.mean(axis=1),
    error_y=dict(type='data', array=spike_rates.std(axis=1), visible=True),
    mode='lines+markers',
))
fig.update_layout(
    xaxis=dict(title='δt [ms]', type='log'),
    yaxis=dict(title='Output firing rate [sp/s]'),
)
fig.show()
We arrive at a similar-looking plot, where the output firing rate drops once δt approaches 1 ms.
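Before looking for an explanation, it may be worth ruling out the input itself as the culprit. The check below is only a sketch: the separate `rng` generator, the helper `mean_input_count` and the shortened list of timesteps are assumptions made for this illustration, not part of the experiment. It suggests that the number of input spikes per train stays close to `input_rate * duration` = 20 whatever the timestep, since the Poisson count in a bin is allowed to exceed one.

```python
import numpy as np

rng = np.random.default_rng(0)  # separate generator, used only for this check
duration = 4     # seconds
input_rate = 5   # Hz
nb_inputs = 100

def mean_input_count(dt):
    """Average number of spikes per input train over the whole simulation."""
    n_steps = len(np.arange(0, duration, dt))
    counts = rng.poisson(lam=input_rate * dt, size=(n_steps, nb_inputs))
    return counts.sum(axis=0).mean()

for dt in (1e-4, 1e-3, 1e-2):
    print(f"dt = {dt * 1e3:5.1f} ms -> {mean_input_count(dt):.2f} spikes per input "
          f"(expected {input_rate * duration})")
```

So the inputs can deliver several spikes within one timestep, while each output neuron is limited to one.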
Hypothesis
Numerous commenters in the original thread suggested that δt should be chosen according to τ. Of course, the chosen time constant τ has some influence: the smaller the leakage during a timestep, the smaller the possible error on the membrane potential. However, I am skeptical that this is the main cause, given the stochastic nature of Poisson spikes. Since a spike can occur at any time during a timestep, it seems likely that the overestimation of the membrane potential will roughly cancel out the underestimation.
My hypothesis is different: I contend that the real cause lies elsewhere. Specifically, owing to the nature of the simulation, a neuron can only emit a single spike within a given timestep. Consequently, the LIF neuron enters a sort of implicit refractory period for the duration of the timestep. When the timestep is large (greater than 1 ms in this instance), the neuron experiences a prolonged refractory period, and input that would have triggered additional spikes during this interval is lost.
If the assumption is correct, i.e. the timestep forces an implicit refractory period, then a large explicit refractory period combined with a small δt should yield the same result as a large δt. If we add a refractory period to the experiment above, we see that they do indeed have a similar effect:
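The one-spike-per-timestep limit can be made visible with a deliberately simple toy case. The sketch below is an assumption-laden illustration rather than the experiment itself: it drives a single neuron with a constant, hypothetical input of 3200 threshold units per second, applies the same decay-then-add update, and compares the resulting rate with the ceiling of 1/δt.

```python
import numpy as np

def hard_reset_rate(dt, drive_per_second=3200.0, tau=0.010, thresh=1.0, duration=1.0):
    """Firing rate (Hz) of one LIF neuron under constant drive, one spike max per step."""
    n_steps = int(round(duration / dt))
    decay = np.exp(-dt / tau)
    u, n_spikes = 0.0, 0
    for _ in range(n_steps):
        u = decay * u + drive_per_second * dt  # decay, then add this step's input
        if u > thresh:
            n_spikes += 1
            u = 0.0  # hard reset: anything above the threshold is discarded
    return n_spikes / duration

for dt in (1e-4, 1e-3, 1e-2):
    print(f"dt = {dt * 1e3:5.1f} ms  rate = {hard_reset_rate(dt):7.1f} Hz  "
          f"ceiling 1/dt = {1 / dt:7.1f} Hz")
```

With these hypothetical numbers, the 0.1 ms simulation fires at 2500 Hz, while the 1 ms and 10 ms simulations sit exactly on their 1/δt ceilings of 1000 Hz and 100 Hz: the same input produces fewer spikes simply because the timestep is longer.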
fig = go.Figure()
for refractory_period in [0.001, 0.01, 0.1]:  # seconds
    spike_rates = np.zeros((len(dts), nb_outputs))
    for i, dt in enumerate(dts):
        time = np.arange(0, duration, dt)
        # Refractory period expressed in timesteps (0 when dt exceeds it)
        refrac_clk = int(refractory_period / dt)
        u = np.zeros(nb_outputs)
        refrac_cntr = np.zeros(nb_outputs, dtype=int)  # remaining refractory steps
        _exp = np.exp(-dt / tau)
        input_spikes = np.random.poisson(lam=input_rate * dt, size=(len(time), nb_inputs))
        weighted_input_spikes = input_spikes @ weights.T
        spike_count = np.zeros(nb_outputs)
        for j in range(len(time)):
            # Only neurons outside their refractory period integrate input
            non_refrac = refrac_cntr == 0
            u[non_refrac] = _exp * u[non_refrac] + weighted_input_spikes[j, non_refrac]
            spikes = u > thresh
            spike_count += spikes
            u[spikes] = 0  # hard reset
            refrac_cntr = np.maximum(refrac_cntr - 1, 0)
            refrac_cntr[spikes] += refrac_clk
        spike_rates[i] += spike_count / duration
    fig.add_trace(go.Scatter(
        x=dts * 1000, y=spike_rates.mean(axis=1),
        error_y=dict(type='data', array=spike_rates.std(axis=1), visible=True),
        mode='lines+markers',
        name=f"Refrac.: {1000 * refractory_period:.1f}ms",
    ))
fig.update_layout(
    xaxis=dict(title='δt [ms]', type='log'),
    yaxis=dict(title='Output firing rate [sp/s]'),
)
fig.show()
As we see, the output firing rates align when δt is equal to the refractory period. Therefore, the model is actually correct. The only difference is that the effective refractory period is the maximum of δt and the explicit refractory period.
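The maximum rule can be compared with what the loop above enforces on paper. The sketch below is an illustration under my reading of that loop (the helper `min_interspike_interval` is hypothetical): after a spike, a neuron is frozen for `refrac_clk` steps and can fire again one step later, so the shortest interval between two spikes is `(refrac_clk + 1) * dt`.

```python
def min_interspike_interval(dt, refractory_period):
    """Shortest possible delay between two spikes in the refractory loop, in seconds."""
    refrac_clk = int(refractory_period / dt)  # same conversion as in the simulation
    # Frozen for refrac_clk steps after a spike, then able to fire one step later.
    return (refrac_clk + 1) * dt

for refractory_period in (0.001, 0.01, 0.1):
    for dt in (1e-5, 1e-3, 3e-2):
        isi = min_interspike_interval(dt, refractory_period)
        rule = max(dt, refractory_period)
        print(f"refrac = {refractory_period * 1e3:6.1f} ms, dt = {dt * 1e3:6.2f} ms: "
              f"min ISI = {isi * 1e3:7.2f} ms, max rule = {rule * 1e3:7.2f} ms")
```

Under this reading, the two quantities coincide whenever δt exceeds the explicit refractory period and agree to within one timestep otherwise, with the largest gap when δt is comparable to the refractory period.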
Solution
The solution is quite simple. The timestep forces an implicit refractory period because the neuron can only spike once per timestep. If we remove this limitation, the implicit refractory period disappears and the output firing rate should be constant regardless of the timestep.
To do so, we count how many times the threshold fits into the membrane potential, which estimates how many times the neuron would spike in one timestep:
spikes = ⌊max(u, 0) / thresh⌋
We also edit the reset so that we subtract the threshold once per emitted spike (a soft reset). This is more precise when dealing with large timesteps, as accumulated membrane potential is not wasted by an early spike during a timestep.
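The difference between the two rules is easiest to see on a single update. The following sketch applies both to a handful of hypothetical membrane potentials (the values in `u` are made up for illustration):

```python
import numpy as np

thresh = 1.0
u = np.array([0.4, 1.2, 2.5, -0.3])  # hypothetical potentials after one coarse update

# Original rule: at most one spike per timestep, hard reset to zero
hard_spikes = (u > thresh).astype(int)
u_hard = np.where(u > thresh, 0.0, u)

# Modified rule: as many spikes as thresholds crossed, soft reset
soft_spikes = np.floor(np.maximum(u, 0) / thresh).astype(int)
u_soft = u - soft_spikes * thresh

print("hard:", hard_spikes, u_hard)
print("soft:", soft_spikes, u_soft)
```

The neuron at 2.5 emits one spike and loses everything under the hard reset, but emits two spikes and keeps a residual of 0.5 under the soft reset. Negative potentials are clipped by `np.maximum`, so they never produce a negative spike count.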
# Full simulation with multiple spikes per timestep and a soft reset
np.random.seed(0x1B)

spike_rates = np.zeros((len(dts), nb_outputs))
for i, dt in enumerate(dts):
    time = np.arange(0, duration, dt)
    u = np.zeros(nb_outputs)
    _exp = np.exp(-dt / tau)
    input_spikes = np.random.poisson(lam=input_rate * dt, size=(len(time), nb_inputs))
    weighted_input_spikes = input_spikes @ weights.T
    spike_count = np.zeros(nb_outputs)
    for j in range(len(time)):
        u = _exp * u + weighted_input_spikes[j]
        spikes = np.floor(np.maximum(u, 0) / thresh)  # multiple spikes per timestep
        spike_count += spikes
        u -= spikes * thresh  # soft reset
    spike_rates[i] += spike_count / duration

fig = go.Figure(go.Scatter(
    x=dts * 1000, y=spike_rates.mean(axis=1),
    error_y=dict(type='data', array=spike_rates.std(axis=1), visible=True),
    mode='lines+markers',
))
fig.update_layout(
    xaxis=dict(title='δt [ms]', type='log'),
    yaxis=dict(title='Output firing rate [sp/s]'),
)
fig.show()
And voilà! We get the expected firing rate across all timesteps. While this solution is very interesting for computational neuroscientists, it partly removes the energy friendliness of spiking neural networks: the spikes are no longer binary, and the reset involves some arithmetic.