Training a Simple Spiking Neural Network (SNN)


This notebook introduces end-to-end training of a simple SNN with Aidge. You will build a small network, connect it to an MNIST input pipeline, define optimization components, and run a short training loop.

What you will learn:

  1. How to create a minimal SNN graph with Aidge operators

  2. How to load and preprocess MNIST for batched training

  3. How to configure optimizer and scheduler objects

  4. How to run a timestep-aware training loop and monitor the loss

Install Requirements

Required Aidge packages: aidge_core, aidge_backend_cpu, and aidge_learning.

Required data utilities: torch and torchvision (used here to build MNIST dataloaders).

If you are using a local development environment, prefer package versions that match your source checkout to avoid API mismatches.

[ ]:
%pip install aidge-core \
    aidge-backend-cpu \
    aidge-learning

%pip install torch torchvision

Next, import the modules used throughout the tutorial.

The import block includes:

  • Aidge core, backend, and learning APIs

  • NumPy for array conversions

  • Torch and TorchVision utilities for dataset loading

  • Basic runtime configuration to reduce log verbosity

[ ]:
import aidge_core
import aidge_backend_cpu
import aidge_learning

import numpy as np

# Required to load the MNIST dataset
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Raise the log level to keep console output short
aidge_core.Log.set_console_level(aidge_core.Level.Fatal)

Create the Aidge Model

Architecture used in this tutorial: Input -> FC -> LIF -> FC -> LIF -> Stack

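In this example, we create a simple two-layer network: an FC (fully connected) layer followed by a LIF (leaky integrate-and-fire) layer, then a second FC layer followed by a second LIF layer. In the code, a Pop operator supplies the input sequence to the first FC layer, and a Stack operator collects the outputs of all timesteps.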

First, define the model constants and SNN hyperparameters.

Important values in this tutorial:

  • num_steps: number of simulation timesteps per input sample

  • in_channels, hidden_channels, out_channels: layer dimensions

  • beta, threshold, reset_type: leaky neuron dynamics

  • batch_size: number of samples processed per iteration

[ ]:
num_steps = 10

# Network topology
in_channels = 28 * 28  # H*W of MNIST
hidden_channels = 1000
out_channels = 10

# SNN parameters
beta = 0.95
threshold = 1.0

batch_size = 64

reset_type = aidge_core.leaky_reset.subtraction  # Type of reset for leaky
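
To make the roles of beta, threshold, and the subtraction reset concrete, here is a minimal pure-Python sketch of one common formulation of a leaky integrate-and-fire neuron. It is an illustration only: the function name lif_simulate is hypothetical, and the exact update order inside Aidge's Leaky operator may differ.

```python
def lif_simulate(inputs, beta=0.95, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron over a list of input currents.

    Assumed formulation (not taken from the Aidge sources):
      1. the membrane potential decays by `beta` and integrates the input,
      2. the neuron spikes when the potential exceeds `threshold`,
      3. reset by subtraction removes `threshold` from the potential after a spike.
    """
    mem = 0.0
    spikes = []
    for x in inputs:
        mem = beta * mem + x                 # leak, then integrate
        spike = 1 if mem > threshold else 0  # fire when above threshold
        mem -= spike * threshold             # reset by subtraction
        spikes.append(spike)
    return spikes


if __name__ == "__main__":
    # A constant input of 0.6 makes the neuron fire on every second step.
    print(lif_simulate([0.6, 0.6, 0.6, 0.6]))  # [0, 1, 0, 1]
```

A beta closer to 1 makes the membrane potential decay more slowly, so weaker inputs accumulate into spikes over more timesteps.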

Then build the graph and initialize trainable parameters.

This block creates each node, connects the operators in execution order, compiles the graph on CPU, initializes the FC weights with He initialization, and sets the FC biases to zero.

[ ]:
# Creation of the network
stack = aidge_core.Stack(num_steps, name="stack")
pop = aidge_core.Pop(name="pop")
fc1 = aidge_core.FC(in_channels, hidden_channels, name="fc1")
lif1 = aidge_core.Leaky(num_steps + 1, beta, threshold, reset_type, False, name="lif1")
fc2 = aidge_core.FC(hidden_channels, out_channels, name="fc2")
lif2 = aidge_core.Leaky(num_steps + 1, beta, threshold, reset_type, True, name="lif2")

model = aidge_core.GraphView()

# Connect operators
pop.add_child(fc1, 0, 0)
fc1.add_child(lif1, 0, 0)
lif1.add_child(fc2, 0, 0)
fc2.add_child(lif2, 0, 0)
lif2.add_child(stack, 1, 0)


model.add({pop, fc1, lif1, fc2, lif2, stack})
model.compile("cpu", aidge_core.dtype.float32)

# Initialize parameters
aidge_core.init_producer(model, "Producer-1>FC", aidge_core.he_filler)
aidge_core.init_producer(
    model,
    "Producer-2>FC",
    lambda x: aidge_core.constant_filler(x, 0.0),
)
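
The data flow through this graph can be sketched in plain NumPy to show the tensor shapes at each stage. This is a rough illustration under stated assumptions, not Aidge's implementation: the helper names (he_normal, lif, snn_forward) are hypothetical, the weight layout is chosen for readability, and stacking the membrane potential of the last LIF layer is an assumption about which output feeds Stack.

```python
import numpy as np


def he_normal(rng, fan_in, fan_out):
    """He initialization: zero-mean normal with standard deviation sqrt(2 / fan_in)."""
    return rng.normal(0.0, np.sqrt(2.0 / fan_in), size=(fan_in, fan_out))


def lif(mem, current, beta, threshold):
    """One LIF update with reset by subtraction; returns (spikes, new membrane)."""
    mem = beta * mem + current
    spikes = (mem > threshold).astype(np.float32)
    return spikes, mem - spikes * threshold


def snn_forward(x_seq, w1, b1, w2, b2, beta=0.95, threshold=1.0):
    """Run Pop -> FC -> LIF -> FC -> LIF -> Stack over a (T, B, in) sequence."""
    num_steps, batch, _ = x_seq.shape
    mem1 = np.zeros((batch, w1.shape[1]), dtype=np.float32)
    mem2 = np.zeros((batch, w2.shape[1]), dtype=np.float32)
    outputs = []
    for t in range(num_steps):
        x_t = x_seq[t]                                           # Pop: one timestep
        spk1, mem1 = lif(mem1, x_t @ w1 + b1, beta, threshold)   # FC1 + LIF1
        spk2, mem2 = lif(mem2, spk1 @ w2 + b2, beta, threshold)  # FC2 + LIF2
        outputs.append(mem2)  # assumption: the membrane potential is stacked
    return np.stack(outputs)                                     # Stack: (T, B, out)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w1, b1 = he_normal(rng, 784, 1000), np.zeros(1000)  # He weights, zero biases
    w2, b2 = he_normal(rng, 1000, 10), np.zeros(10)
    x_seq = rng.random((10, 64, 784))                   # (num_steps, batch, in)
    print(snn_forward(x_seq, w1, b1, w2, b2).shape)     # (10, 64, 10)
```

The membrane state of each LIF layer persists across the loop, which is why the network has to be evaluated one timestep at a time rather than in a single matrix product.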

After the model is created, define the dataset pipeline.

The helper below downloads MNIST, applies preprocessing transforms, and returns train/test dataloaders with consistent batch sizing (drop_last=True) for fixed-shape tensor handling.

[ ]:
def create_mnist_loader():
    data_path = "/tmp/data"

    # ToTensor scales pixels to [0, 1]; Normalize((0,), (1,)) leaves them unchanged
    transform = transforms.Compose(
        [
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,)),
        ]
    )

    mnist_train = datasets.MNIST(
        data_path, train=True, download=True, transform=transform
    )
    mnist_test = datasets.MNIST(
        data_path, train=False, download=True, transform=transform
    )

    # Create dataloaders; drop_last=True discards the final incomplete batch
    # (batch_size is the global defined with the hyperparameters above)
    train_loader = DataLoader(mnist_train, batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(mnist_test, batch_size, shuffle=True, drop_last=True)

    return train_loader, test_loader
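
As a quick check of what drop_last=True implies, the sketch below counts the batches each loader yields per epoch. The helper batches_per_epoch is hypothetical and only reproduces the arithmetic; it does not call into torch. The dataset sizes are the standard MNIST split (60,000 training and 10,000 test images).

```python
def batches_per_epoch(num_samples, batch_size, drop_last=True):
    """Number of batches a loader yields in one pass over the dataset."""
    full, remainder = divmod(num_samples, batch_size)
    if remainder and not drop_last:
        return full + 1  # the last batch is smaller than batch_size
    return full


if __name__ == "__main__":
    print(batches_per_epoch(60_000, 64))  # 937 full batches, 32 images dropped
    print(batches_per_epoch(10_000, 64))  # 156 full batches, 16 images dropped
```

Because every batch then holds exactly batch_size samples, the reshape with view(batch_size, -1) in the training loop never sees a short batch.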

Now create the optimizer (Adam in this example).

You can replace Adam with another optimizer depending on your experiment.

The optimizer is configured with a constant learning rate scheduler and attached to FC weights and biases so those parameters are updated during backpropagation.

[ ]:
opt = aidge_learning.Adam()
opt.set_learning_rate_scheduler(aidge_learning.constant_lr(0.0005))
opt.set_parameters(
    [fc1.get_parent(1), fc1.get_parent(2), fc2.get_parent(1), fc2.get_parent(2)]
)
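
For readers unfamiliar with Adam, the update it applies to each parameter can be sketched for a single scalar as follows. This is the textbook rule, not a copy of aidge_learning's implementation: adam_step is a hypothetical helper, and the beta1, beta2, and eps defaults are the commonly cited values, assumed here because the tutorial does not show Aidge's defaults.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.0005, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; returns (param, m, v)."""
    m = beta1 * m + (1.0 - beta1) * grad         # running mean of gradients
    v = beta2 * v + (1.0 - beta2) * grad * grad  # running mean of squared gradients
    m_hat = m / (1.0 - beta1 ** t)               # bias correction, t starts at 1
    v_hat = v / (1.0 - beta2 ** t)
    param -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


if __name__ == "__main__":
    # On the first step the update magnitude is close to the learning rate,
    # whatever the size of the gradient.
    param, m, v = adam_step(1.0, grad=2.0, m=0.0, v=0.0, t=1)
    print(param)
```

With a constant learning rate of 0.0005, each parameter therefore moves by roughly 0.0005 per step early in training.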

Then create the execution scheduler.

SequentialScheduler drives forward and backward passes in graph order and manages internal execution state between iterations.

[ ]:
scheduler = aidge_core.SequentialScheduler(model, False)

The training loop is similar to a standard ANN loop, with one key difference: the loss must account for multiple timesteps.

For each iteration, the notebook:

  1. Resets scheduling state

  2. Loads a batch and reshapes it into a timestep sequence

  3. Runs a forward pass

  4. Computes multiStepCELoss

  5. Calls scheduler.backward() then opt.update()

You should observe the reported loss trend downward across iterations.
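
Steps 2 and 4 above can be sketched in NumPy as follows. This is an illustration under assumptions: to_timestep_sequence, one_hot, and multi_step_ce are hypothetical helpers, and the loss shown (softmax cross-entropy averaged over the batch and summed over timesteps) is one plausible definition of a multi-step cross-entropy, not necessarily what aidge_learning.loss.multiStepCELoss computes.

```python
import numpy as np


def to_timestep_sequence(images, num_steps):
    """Flatten (B, C, H, W) images and repeat them along a new leading time axis."""
    flat = images.reshape(images.shape[0], -1)
    return np.repeat(flat[np.newaxis, ...], num_steps, axis=0)  # (T, B, C*H*W)


def one_hot(labels, num_classes):
    """Convert integer labels of shape (B,) to one-hot rows of shape (B, classes)."""
    return np.eye(num_classes, dtype=np.float32)[labels]


def multi_step_ce(outputs, targets):
    """Cross-entropy averaged over the batch and summed over timesteps (assumed).

    outputs: (T, B, classes) per-timestep outputs; targets: (B, classes) one-hot.
    """
    shifted = outputs - outputs.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    per_step = -(targets[np.newaxis] * log_probs).sum(axis=-1).mean(axis=-1)  # (T,)
    return per_step.sum()


if __name__ == "__main__":
    images = np.zeros((64, 1, 28, 28), dtype=np.float32)
    print(to_timestep_sequence(images, 10).shape)  # (10, 64, 784)

    # Uniform outputs give a loss of log(10) per timestep for 10 classes.
    outputs = np.zeros((10, 64, 10), dtype=np.float32)
    targets = one_hot(np.zeros(64, dtype=np.int64), 10)
    print(multi_step_ce(outputs, targets))
```

The same static image is presented at every timestep here, matching the repeat call in the training loop; the same target is compared against the output of every timestep.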

Run the Learning Loop

This final block performs a short demonstration run (10 iterations) to validate the full training pipeline.

[ ]:
train_loader, test_loader = create_mnist_loader()
train_iter = iter(train_loader)  # Single iterator so each step draws a new batch

for i in range(10):
    scheduler.reset_scheduling()

    # Load a training batch (already on CPU)
    data, labels = next(train_iter)
    # Flatten to (batch_size, 784), then repeat along a new leading timestep axis
    data = data.view(batch_size, -1).unsqueeze(0).repeat(num_steps, 1, 1)
    input_tensor = aidge_core.Tensor(data.numpy())

    # One-hot encode the labels as float32 for the loss
    targets = torch.nn.functional.one_hot(labels, 10).numpy().astype(np.float32)
    target_tensor = aidge_core.Tensor(targets)

    # Forward pass
    pop.get_operator().associate_input(0, input_tensor)
    result = scheduler.forward()[1]

    loss_value = aidge_learning.loss.multiStepCELoss(result, target_tensor, num_steps)
    print("Loss value:", loss_value)

    # Backward pass and parameter update
    scheduler.backward()
    opt.update()