Basic concept: Timestep
----

The code was run with spikingjelly 0.0.0.13 (Nov 2022). You may want to pin this version to ensure compatibility.


## 1. Running an activation function

In an SNN, the activation function is usually a comparison between the current membrane potential (v) and the threshold (v_th): a neuron emits a spike (1) when v >= v_th and stays silent (0) otherwise.

In [1]:
import torch

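v = torch.rand([8])  # membrane potentials of 8 neurons, uniform in [0, 1)
v_th = 0.5  # firing threshold
spike = (v >= v_th).to(v)  # compare, then cast the bool result to v's dtype (float)
print('spike =', spike)
# example output (it changes on every run because v is random):
# spike = tensor([0., 0., 0., 1., 1., 0., 1., 0.])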

spike = tensor([0., 1., 1., 1., 0., 1., 1., 1.])
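To make the comparison concrete without random numbers, here is a minimal pure-Python sketch of the same rule. `heaviside_spike` is a hypothetical helper, not a SpikingJelly function. A potential exactly equal to the threshold counts as a spike because the comparison is `>=`.

```python
def heaviside_spike(v, v_th):
    """Hypothetical helper: emit 1.0 where v >= v_th and 0.0 elsewhere (element-wise)."""
    return [1.0 if vi >= v_th else 0.0 for vi in v]


v = [0.12, 0.50, 0.73, 0.49]  # fixed potentials instead of torch.rand
print(heaviside_spike(v, v_th=0.5))  # [0.0, 1.0, 1.0, 0.0]
```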


## 2. Step mode

There are two step modes, selected with the `step_mode` argument:

1. 'm': multiple timesteps
2. 's': single timestep

In [2]:
import torch
from spikingjelly.activation_based import neuron

net = neuron.IFNode(step_mode='m')  # 'm' is the multi-step mode
net.step_mode = 's'  # the mode can be changed after construction; 's' is the single-step mode
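The difference between the two modes can be sketched with a toy, single-neuron integrate-and-fire model in plain Python. `ToyIFNode` is hypothetical; its method names only echo the single-step/multi-step terminology, and the dynamics assume a hard reset (charge v = v + x, fire when v >= v_threshold, then set v back to v_reset), with defaults assumed to match IFNode's (threshold 1.0, reset 0.0). In 's' mode one call consumes one time-step; in 'm' mode one call consumes a whole sequence. In both modes the membrane potential is kept between calls.

```python
class ToyIFNode:
    """Hypothetical single-neuron IF model with a switchable step_mode.

    Assumed dynamics (hard reset): charge v = v + x, fire if v >= v_threshold,
    then set v back to v_reset after a spike.
    """

    def __init__(self, v_threshold=1.0, v_reset=0.0, step_mode='s'):
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.step_mode = step_mode
        self.v = v_reset  # membrane potential, kept between calls

    def single_step_forward(self, x):
        # One time-step: x is the input at time t.
        self.v += x
        spike = 1.0 if self.v >= self.v_threshold else 0.0
        if spike:
            self.v = self.v_reset
        return spike

    def multi_step_forward(self, x_seq):
        # T time-steps: loop over the leading (time) dimension.
        return [self.single_step_forward(x) for x in x_seq]

    def __call__(self, x):
        if self.step_mode == 's':
            return self.single_step_forward(x)
        if self.step_mode == 'm':
            return self.multi_step_forward(x)
        raise ValueError(f"unknown step_mode: {self.step_mode!r}")


node = ToyIFNode(step_mode='m')
print(node([0.6, 0.6, 0.6, 0.6]))  # [0.0, 1.0, 0.0, 1.0]
node.step_mode = 's'
print(node(0.6))  # 0.0 (v was reset to 0.0 by the last spike)
```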


## 3. Data format

In spikingjelly.activation_based, there are two data formats:

1. Data for a single time-step (step_mode = 's') with shape = \[N, \*\], where N is the batch dimension and \* represents any additional dimensions.

2. Data for multiple time-steps (step_mode = 'm') with shape = \[T, N, \*\], where T is the time-step dimension, N is the batch dimension and \* represents any additional dimensions.
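The two formats differ only by the leading T dimension. Below is a small sketch of the shape bookkeeping, using the T, N, C, H, W values from the next cell; both helpers are hypothetical. The second one reflects a common trick for stateless layers such as convolutions: because they have no memory across time-steps, `[T, N, *]` data can be viewed as one large batch of shape `[T * N, *]`. Stateful neurons cannot be treated this way, since they must see the time-steps in order.

```python
def single_step_shape(seq_shape):
    """Hypothetical helper: [T, N, *] -> [N, *], the shape of x_seq[t]."""
    if len(seq_shape) < 2:
        raise ValueError("multi-step data needs at least the [T, N] dimensions")
    return tuple(seq_shape[1:])


def merge_time_into_batch(seq_shape):
    """Hypothetical helper: [T, N, *] -> [T * N, *] for a stateless layer."""
    T, N, *rest = seq_shape
    return (T * N, *rest)


seq_shape = (4, 1, 3, 8, 8)  # [T, N, C, H, W]
print(single_step_shape(seq_shape))      # (1, 3, 8, 8)
print(merge_time_into_batch(seq_shape))  # (4, 3, 8, 8)
```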


In [3]:
import torch
from spikingjelly.activation_based import neuron

net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = []
for t in range(T):
    x = x_seq[t]  # x.shape = [N, C, H, W]
    y = net_s(x)  # y.shape = [N, C, H, W]
    y_seq.append(y.unsqueeze(0))  # add a time dimension: [1, N, C, H, W]

y_seq = torch.cat(y_seq)  # concatenate along dim 0 (time)
# y_seq.shape = [T, N, C, H, W]

print("shape of x: "+str(x.shape))
print("shape of y: "+str(y.shape))
print("shape of y_seq: "+str(y_seq.shape))

shape of x: torch.Size([1, 3, 8, 8])
shape of y: torch.Size([1, 3, 8, 8])
shape of y_seq: torch.Size([4, 1, 3, 8, 8])


In the previous example:
1. `net_s` is an IF (integrate-and-fire) node in single time-step mode (step_mode='s'). It receives the input x and outputs the spikes y.
2. T = 4 is the number of time-steps.
3. N = 1 is the batch size.
4. \[C, H, W\] is the shape of one sample at a single time-step.

In the loop over t in range(T), we feed x = x_seq[t] (shape \[N, C, H, W\]) into net_s and collect the output y. Concatenating the outputs gives y_seq with shape \[T, N, C, H, W\]. After the loop, x and y hold the values of the last time-step.
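Because `net_s` keeps its membrane potential between calls, it still holds state after the loop finishes, and feeding it another sequence without a reset would continue from that state. The toy IF function below (hypothetical, hard reset, threshold 1.0) shows the effect on a single neuron. In SpikingJelly, the state of the stateful modules in a network can be cleared with `functional.reset_net(net)`; the later cells in this notebook simply create a fresh node instead.

```python
def run_if(x_seq, v=0.0, v_threshold=1.0, v_reset=0.0):
    """Hypothetical toy IF neuron: return (spikes, final membrane potential)."""
    spikes = []
    for x in x_seq:
        v += x  # charge
        s = 1.0 if v >= v_threshold else 0.0  # fire
        if s:
            v = v_reset  # hard reset
        spikes.append(s)
    return spikes, v


x_seq = [0.6, 0.6, 0.6]
y1, v = run_if(x_seq)       # fresh neuron
y2, _ = run_if(x_seq, v=v)  # state carried over, like calling net_s again without a reset
y3, _ = run_if(x_seq)       # state reset to v_reset first
print(y1, y2, y3)  # [0.0, 1.0, 0.0] [1.0, 0.0, 1.0] [0.0, 1.0, 0.0]
```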





In [4]:
print(x)  # input at the last time-step (t = T - 1)

tensor([[[[0.5071, 0.7052, 0.4534, 0.5574, 0.8354, 0.0320, 0.3970, 0.0040],
          [0.0341, 0.7944, 0.1181, 0.4376, 0.3779, 0.8777, 0.2491, 0.6237],
          [0.7391, 0.2109, 0.6843, 0.0922, 0.4819, 0.7587, 0.9131, 0.4688],
          [0.8240, 0.9039, 0.1107, 0.3576, 0.0064, 0.5205, 0.3281, 0.1259],
          [0.7636, 0.7522, 0.2319, 0.6788, 0.3135, 0.5233, 0.9580, 0.0841],
          [0.3263, 0.9880, 0.5973, 0.7657, 0.2332, 0.7054, 0.0202, 0.4270],
          [0.1711, 0.9571, 0.7265, 0.7625, 0.2239, 0.9089, 0.0328, 0.9163],
          [0.5642, 0.7324, 0.1692, 0.1965, 0.4641, 0.4657, 0.4648, 0.4485]],

         [[0.6333, 0.1055, 0.9311, 0.8106, 0.6487, 0.7959, 0.9775, 0.3199],
          [0.5284, 0.5761, 0.0359, 0.6869, 0.3350, 0.2341, 0.4242, 0.7159],
          [0.6498, 0.2764, 0.0762, 0.5623, 0.9301, 0.2080, 0.4853, 0.6036],
          [0.5455, 0.4864, 0.0172, 0.0084, 0.9250, 0.8343, 0.4295, 0.2044],
          [0.3006, 0.5396, 0.0617, 0.4584, 0.4168, 0.3356, 0.2164, 0.1618],
          

In [5]:
print(y)  # output spikes at the last time-step (t = T - 1)

tensor([[[[0., 0., 0., 1., 1., 0., 0., 0.],
          [0., 0., 0., 1., 0., 1., 0., 0.],
          [0., 0., 1., 1., 1., 0., 1., 1.],
          [0., 1., 0., 0., 0., 0., 0., 0.],
          [0., 1., 0., 0., 0., 0., 1., 0.],
          [0., 1., 1., 1., 1., 1., 0., 0.],
          [0., 1., 1., 0., 1., 1., 1., 1.],
          [0., 1., 0., 0., 0., 0., 0., 0.]],

         [[1., 0., 1., 0., 0., 1., 1., 0.],
          [1., 0., 0., 0., 0., 0., 0., 1.],
          [1., 0., 0., 0., 1., 0., 1., 0.],
          [0., 0., 0., 0., 1., 1., 0., 0.],
          [1., 1., 0., 0., 1., 1., 1., 1.],
          [1., 0., 0., 1., 0., 1., 1., 0.],
          [0., 0., 0., 1., 0., 0., 1., 1.],
          [0., 0., 1., 0., 0., 0., 1., 1.]],

         [[0., 1., 0., 0., 0., 0., 0., 0.],
          [1., 0., 0., 0., 1., 0., 0., 0.],
          [1., 0., 0., 0., 0., 1., 0., 1.],
          [1., 1., 0., 1., 1., 0., 1., 1.],
          [0., 0., 0., 0., 0., 1., 0., 0.],
          [0., 0., 0., 0., 1., 1., 1., 1.],
          [0., 0., 1., 0., 1

## 4. Multiple time-steps

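The loop above can be replaced by a single call that processes all time-steps at once. There are two ways to do this.

The first way is to keep the single-step node and pass it to `functional.multi_step_forward`, which feeds x_seq[t] to the node for every t in range(T) and concatenates the outputs along the time dimension: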


In [6]:
import torch
from spikingjelly.activation_based import neuron, functional
net_s = neuron.IFNode(step_mode='s')  # a fresh node, so no state is carried over from the loop above

y_seq = functional.multi_step_forward(x_seq, net_s)  # y_seq.shape = [T, N, C, H, W]
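Conceptually, `functional.multi_step_forward` moves the loop from the previous example into a function. The pure-Python sketch below is a hypothetical re-implementation of that idea, not SpikingJelly's source: the time loop is on the outside, and at every time-step the input passes through one callable or a list of callables in order. As far as we know, the real function also accepts a list or `nn.Sequential` of single-step modules and returns a tensor rather than a list. `make_toy_if` (a stateful single-step IF neuron with a hard reset) and `double` (a stand-in for a stateless layer) are illustrative assumptions.

```python
def multi_step_forward_sketch(x_seq, single_step_modules):
    """Hypothetical sketch: apply single-step callables to each time-step in order."""
    if not isinstance(single_step_modules, (list, tuple)):
        single_step_modules = [single_step_modules]
    y_seq = []
    for x_t in x_seq:  # time loop on the outside
        for module in single_step_modules:
            x_t = module(x_t)  # each module sees one time-step
        y_seq.append(x_t)
    return y_seq


def make_toy_if(v_threshold=1.0):
    """Return a stateful single-step IF callable (hard reset to 0.0)."""
    state = {"v": 0.0}

    def step(x):
        state["v"] += x
        if state["v"] >= v_threshold:
            state["v"] = 0.0
            return 1.0
        return 0.0

    return step


def double(x):
    """Stand-in for a stateless single-step layer."""
    return 2 * x


print(multi_step_forward_sketch([0.3, 0.3, 0.3], [double, make_toy_if()]))  # [0.0, 1.0, 0.0]
```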

In [7]:
print(y_seq[-1])  # output at the last time-step; same as y printed above

tensor([[[[0., 0., 0., 1., 1., 0., 0., 0.],
          [0., 0., 0., 1., 0., 1., 0., 0.],
          [0., 0., 1., 1., 1., 0., 1., 1.],
          [0., 1., 0., 0., 0., 0., 0., 0.],
          [0., 1., 0., 0., 0., 0., 1., 0.],
          [0., 1., 1., 1., 1., 1., 0., 0.],
          [0., 1., 1., 0., 1., 1., 1., 1.],
          [0., 1., 0., 0., 0., 0., 0., 0.]],

         [[1., 0., 1., 0., 0., 1., 1., 0.],
          [1., 0., 0., 0., 0., 0., 0., 1.],
          [1., 0., 0., 0., 1., 0., 1., 0.],
          [0., 0., 0., 0., 1., 1., 0., 0.],
          [1., 1., 0., 0., 1., 1., 1., 1.],
          [1., 0., 0., 1., 0., 1., 1., 0.],
          [0., 0., 0., 1., 0., 0., 1., 1.],
          [0., 0., 1., 0., 0., 0., 1., 1.]],

         [[0., 1., 0., 0., 0., 0., 0., 0.],
          [1., 0., 0., 0., 1., 0., 0., 0.],
          [1., 0., 0., 0., 0., 1., 0., 1.],
          [1., 1., 0., 1., 1., 0., 1., 1.],
          [0., 0., 0., 0., 0., 1., 0., 0.],
          [0., 0., 0., 0., 1., 1., 1., 1.],
          [0., 0., 1., 0., 1

The second way is to use the multi-step mode (step_mode='m'), in which the node takes the whole sequence x_seq of shape \[T, N, \*\] in one call:

In [8]:
import torch
from spikingjelly.activation_based import neuron

net_m = neuron.IFNode(step_mode='m')  # multi-step node: the loop over T runs inside the node

y_seq = net_m(x_seq)  # x_seq.shape = [T, N, C, H, W]
# y_seq.shape = [T, N, C, H, W]

In [9]:
print(y_seq[-1])  # output at the last time-step; same as the two previous results

tensor([[[[0., 0., 0., 1., 1., 0., 0., 0.],
          [0., 0., 0., 1., 0., 1., 0., 0.],
          [0., 0., 1., 1., 1., 0., 1., 1.],
          [0., 1., 0., 0., 0., 0., 0., 0.],
          [0., 1., 0., 0., 0., 0., 1., 0.],
          [0., 1., 1., 1., 1., 1., 0., 0.],
          [0., 1., 1., 0., 1., 1., 1., 1.],
          [0., 1., 0., 0., 0., 0., 0., 0.]],

         [[1., 0., 1., 0., 0., 1., 1., 0.],
          [1., 0., 0., 0., 0., 0., 0., 1.],
          [1., 0., 0., 0., 1., 0., 1., 0.],
          [0., 0., 0., 0., 1., 1., 0., 0.],
          [1., 1., 0., 0., 1., 1., 1., 1.],
          [1., 0., 0., 1., 0., 1., 1., 0.],
          [0., 0., 0., 1., 0., 0., 1., 1.],
          [0., 0., 1., 0., 0., 0., 1., 1.]],

         [[0., 1., 0., 0., 0., 0., 0., 0.],
          [1., 0., 0., 0., 1., 0., 0., 0.],
          [1., 0., 0., 0., 0., 1., 0., 1.],
          [1., 1., 0., 1., 1., 0., 1., 1.],
          [0., 0., 0., 0., 0., 1., 0., 0.],
          [0., 0., 0., 0., 1., 1., 1., 1.],
          [0., 0., 1., 0., 1