from typing import Callable
import numpy as np
import torch
import torch.nn as nn
from arcana.surrogate import fast_sigmoid
class Round(torch.autograd.function.InplaceFunction):
    """
    Round operation with a straight-through estimator (STE) surrogate gradient.

    The forward pass rounds to the nearest integer with ``torch.round`` (ties
    are rounded to even). The backward pass forwards the incoming gradient
    unchanged, i.e. the rounding is treated as the identity for
    differentiation.
    """

    @staticmethod
    def forward(ctx, input):
        # The backward pass needs nothing from the forward pass, so nothing is
        # stored on ``ctx``; keeping a reference there would only hold the
        # input tensor alive for as long as the graph exists.
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: d round(x) / dx is taken to be 1.
        return grad_output.clone()


# Functional alias used by DPINeuron below. Note that this name shadows the
# built-in ``round`` inside this module.
round = Round.apply

# Global scale factor applied to the DPINeuron capacitances and bias currents.
SCALING = 1
class DPINeuron(nn.Module):
    """
    DPI neuron model used in the Dynap-SE chip, including AMPA, NMDA, GABAa and
    GABAb synapses.

    The bias parameters of the neurons are:

    | Parameter | Description |
    | :-------- | :---------- |
    | Itau_mem | Soma leakage current |
    | Igain_mem | Soma gain current |
    | Ipfb_th | Positive feedback bias |
    | Ipfb_norm | Positive feedback normalization |
    | refractory | Refractory period |
    | Ith | Firing threshold |
    | Idc | Input constant current |
    | Itau_ampa | AMPA synapse leakage current |
    | Igain_ampa | AMPA synapse gain current |
    | Iw_ampa | AMPA synapse base weight |
    | Inmda_thr | NMDA synapse threshold |
    | Itau_nmda | NMDA synapse leakage current |
    | Igain_nmda | NMDA synapse gain current |
    | Iw_nmda | NMDA synapse base weight |
    | Itau_gabaa | GABAa synapse leakage current |
    | Igain_gabaa | GABAa synapse gain current |
    | Iw_gabaa | GABAa synapse base weight |
    | Itau_gabab | GABAb synapse leakage current |
    | Igain_gabab | GABAb synapse gain current |
    | Iw_gabab | GABAb synapse base weight |

    Any of these can be overridden through ``**kwargs`` using the names above.

    Parameters
    ----------
    n_in: int
        Number of input synapses
    n_out: int
        Number of neurons in the layer
    dt: float
        Simulation timestep in seconds
    surrogate_fn: Callable
        Surrogate gradient function for spiking
    train_Itau_mem: bool
        Flag to train the membrane leakage current bias (through ``beta``)
    train_Igain_mem: bool
        Flag to train the membrane input gain bias (through ``alpha``)
    train_Idc: bool
        Flag to train the input constant current
    train_ampa: bool
        Flag to train the AMPA base weight ``Iw_ampa``. It also sets
        ``requires_grad`` on all four weight matrices (``W_ampa``, ``W_nmda``,
        ``W_gabaa`` and ``W_gabab``).
    train_gabab: bool
        Flag to train the GABAb base weight ``Iw_gabab``
    """

    I0: float = 0.5e-13 * SCALING  # Dark current
    UT: float = 25e-3  # Thermal voltage
    KAPPA: float = (0.75 + 0.66) / 2  # Transistor slope factor
    CMEM: float = 3e-12 * SCALING  # Membrane capacitance
    CAMPA: float = 2e-12 * SCALING  # AMPA synapse capacitance
    CNMDA: float = 2e-12 * SCALING  # NMDA synapse capacitance
    CGABA_A: float = 2e-12 * SCALING  # GABAa synapse capacitance
    CGABA_B: float = 2e-12 * SCALING  # GABAb synapse capacitance
    MAX_FANIN: float = 64  # Maximum number of input synapses per neuron

    def __init__(
        self,
        n_in: int,
        n_out: int,
        dt: float = 1e-3,
        surrogate_fn: Callable = fast_sigmoid,
        train_Itau_mem: bool = False,
        train_Igain_mem: bool = False,
        train_Idc: bool = False,
        train_ampa: bool = False,
        train_gabab: bool = False,
        **kwargs,
    ):
        super(DPINeuron, self).__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.surrogate_fn = surrogate_fn
        self.dt = dt

        # SOMA
        # Parameters
        self.Itau_mem = kwargs.get("Itau_mem", 5e-12) * SCALING
        self.Igain_mem = kwargs.get("Igain_mem", 20e-12) * SCALING
        # alpha = Igain_mem / Itau_mem and beta = 1 + I0 / Itau_mem are the
        # trainable quantities; UpdateParams() derives the currents back from them.
        self.alpha = nn.Parameter(
            torch.tensor(self.Igain_mem / self.Itau_mem), requires_grad=train_Igain_mem
        )
        self.beta = nn.Parameter(
            torch.tensor(1 + DPINeuron.I0 / self.Itau_mem), requires_grad=train_Itau_mem
        )
        # Positive feedback current
        self.Ipfb_th = kwargs.get("Ipfb_th", 500.0e-12) * SCALING
        self.Ipfb_norm = kwargs.get("Ipfb_norm", 1.470e9) / SCALING
        # Other neuron parameters
        # Refractory period, in the same time unit as dt (forward() decrements it by dt).
        self.refP = kwargs.get("refractory", 0.0)
        self.Ith = kwargs.get("Ith", 2000.0e-12) * SCALING  # Firing threshold
        self.Idc = nn.Parameter(
            torch.tensor(kwargs.get("Idc", 1e-12) * SCALING), requires_grad=train_Idc
        )

        # AMPA
        self.train_ampa = train_ampa
        self.Itau_ampa = kwargs.get("Itau_ampa", 20e-12) * SCALING
        self.Igain_ampa = kwargs.get("Igain_ampa", 80e-12) * SCALING
        self.Iw_ampa = nn.Parameter(
            torch.tensor(kwargs.get("Iw_ampa", 80e-12) * SCALING),
            requires_grad=train_ampa,
        )
        if train_ampa:
            # Shrink the gradient of the pA-scale base weight by 1e-12.
            self.Iw_ampa.register_hook(lambda grad: grad * 1e-12)
        self.W_ampa = nn.Parameter(torch.empty(n_out, n_in), requires_grad=train_ampa)

        # NMDA
        self.Itau_nmda = kwargs.get("Itau_nmda", 20e-12) * SCALING
        self.Inmda_thr = kwargs.get("Inmda_thr", 5e-13) * SCALING
        self.Igain_nmda = kwargs.get("Igain_nmda", 80e-12) * SCALING
        # Plain tensor, not a Parameter: the NMDA base weight is not trainable.
        self.Iw_nmda = torch.tensor(kwargs.get("Iw_nmda", 80e-12) * SCALING)
        self.W_nmda = nn.Parameter(torch.empty(n_out, n_in), requires_grad=train_ampa)

        # GABAa
        self.Itau_gabaa = kwargs.get("Itau_gabaa", 20e-12) * SCALING
        self.Igain_gabaa = kwargs.get("Igain_gabaa", 80e-12) * SCALING
        # Plain tensor, not a Parameter: the GABAa base weight is not trainable.
        self.Iw_gabaa = torch.tensor(kwargs.get("Iw_gabaa", 80e-12) * SCALING)
        self.W_gabaa = nn.Parameter(torch.empty(n_out, n_in), requires_grad=train_ampa)

        # GABAb
        self.train_gabab = train_gabab
        self.Itau_gabab = kwargs.get("Itau_gabab", 20e-12) * SCALING
        self.Igain_gabab = kwargs.get("Igain_gabab", 80e-12) * SCALING
        self.Iw_gabab = nn.Parameter(
            torch.tensor(kwargs.get("Iw_gabab", 80e-12) * SCALING),
            requires_grad=train_gabab,
        )
        if train_gabab:
            # Shrink the gradient of the pA-scale base weight by 1e-12.
            self.Iw_gabab.register_hook(lambda grad: grad * 1e-12)
        # NOTE(review): W_gabab follows train_ampa (like W_nmda and W_gabaa),
        # not train_gabab -- confirm this is intended.
        self.W_gabab = nn.Parameter(torch.empty(n_out, n_in), requires_grad=train_ampa)

        # Weights initialization
        nn.init.constant_(self.W_ampa, 1.0)
        nn.init.constant_(self.W_nmda, 0.0)
        nn.init.constant_(self.W_gabaa, 0.0)
        nn.init.constant_(self.W_gabab, 1.0)

        # Mismatch parameters: one per-neuron relative deviation buffer per bias,
        # all zero (no mismatch) until add_mismatch() is called.
        for name in (
            "Idc",
            "Itau_mem",
            "Igain_mem",
            "Itau_ampa",
            "Igain_ampa",
            "Iw_ampa",
            "Itau_nmda",
            "Inmda_thr",
            "Igain_nmda",
            "Iw_nmda",
            "Itau_gabaa",
            "Igain_gabaa",
            "Iw_gabaa",
            "Itau_gabab",
            "Igain_gabab",
            "Iw_gabab",
        ):
            self.register_buffer(f"_{name}_mismatch", torch.zeros(1, self.n_out))
        self.state = None

    def add_mismatch(self, param: str, mismatch: float = 0.1):
        """
        Add device mismatch to one bias by sampling per-neuron deviations.

        The buffer ``_<param>_mismatch`` is filled in place with samples from a
        normal distribution with mean 0 and standard deviation ``mismatch``.
        forward() then uses ``bias * (1 + deviation)``.

        Parameters
        ----------
        param: Bias name to apply the mismatch to (e.g. ``"Itau_mem"``).
        mismatch: Relative variability (standard deviation) to apply [0-1].

        Raises
        ------
        AttributeError
            If the neuron has no mismatch buffer for ``param``.
        """
        try:
            mismatch_parameter = getattr(self, f"_{param}_mismatch")
        except AttributeError as err:
            raise AttributeError(
                f"DPINeuron has no mismatch on attribute {param}"
            ) from err
        torch.nn.init.normal_(mismatch_parameter, mean=0.0, std=mismatch)

    def initialize(self, X):
        """
        Build the initial state for a batch with the same leading dimension as ``X``.

        All currents start at the dark current I0 and the refractory counters
        start at 0. Returns the tuple
        ``(Imem, Iampa, Inmda, Igabaa, Igabab, refractory)``, where each tensor
        has shape ``(X.shape[0], n_out)`` and lives on ``X.device``.
        """
        shape = (X.shape[0], self.n_out)
        Imem = torch.zeros(shape, device=X.device) + self.I0
        Iampa = torch.zeros(shape, device=X.device) + self.I0
        Inmda = torch.zeros(shape, device=X.device) + self.I0
        Igabaa = torch.zeros(shape, device=X.device) + self.I0
        Igabab = torch.zeros(shape, device=X.device) + self.I0
        refractory = torch.zeros(shape, device=X.device)
        return (Imem, Iampa, Inmda, Igabaa, Igabab, refractory)

    @staticmethod
    def I2V(current: float) -> float:
        """
        Convert a current into a voltage: ``(UT / KAPPA) * log(current / I0)``.

        Accepts Python floats / NumPy values (uses ``np.log``) as well as
        tensors (uses ``torch.log`` so gradients and devices are preserved).
        """
        if isinstance(current, torch.Tensor):
            return (DPINeuron.UT / DPINeuron.KAPPA) * torch.log(current / DPINeuron.I0)
        return (DPINeuron.UT / DPINeuron.KAPPA) * np.log(current / DPINeuron.I0)

    @staticmethod
    def V2I(voltage: float) -> float:
        """
        Convert a voltage into a current: ``I0 * exp(voltage * KAPPA / UT)``.

        Accepts Python floats / NumPy values (uses ``np.exp``) as well as
        tensors (uses ``torch.exp`` so gradients and devices are preserved).
        """
        exponent = voltage * DPINeuron.KAPPA / DPINeuron.UT
        if isinstance(voltage, torch.Tensor):
            return DPINeuron.I0 * torch.exp(exponent)
        return DPINeuron.I0 * np.exp(exponent)

    def UpdateParams(self, optimizer, args, kwargs):
        """
        Re-derive the membrane biases from ``alpha``/``beta`` and clamp the weights.

        Call this after every optimizer step so the bias currents stay
        consistent with the trained parameters. The ``(optimizer, args, kwargs)``
        arguments are unused; this signature matches the hook signature of
        ``torch.optim.Optimizer.register_step_post_hook``.
        """
        # Invert beta = 1 + I0 / Itau_mem and alpha = Igain_mem / Itau_mem.
        # NOTE(review): when alpha/beta require grad these become graph-connected
        # tensors that later forward() calls differentiate through -- confirm intended.
        self.Itau_mem = DPINeuron.I0 / (self.beta - 1)
        self.Igain_mem = self.alpha * self.Itau_mem
        self.tau_mem = (DPINeuron.UT / DPINeuron.KAPPA) * DPINeuron.CMEM / self.Itau_mem
        # Base weights are kept at or above the dark current...
        for name in ("Iw_ampa", "Iw_nmda", "Iw_gabaa", "Iw_gabab"):
            tensor = getattr(self, name)
            tensor.data = torch.clamp_min(tensor.data, self.I0)
        # ...and connection-count matrices are kept non-negative.
        for name in ("W_ampa", "W_nmda", "W_gabaa", "W_gabab"):
            tensor = getattr(self, name)
            tensor.data = torch.clamp_min(tensor.data, 0.0)

    def _apply_mismatch(self, value, name: str):
        """Return ``value * (1 + _<name>_mismatch)`` clamped from below at I0."""
        mismatch = getattr(self, f"_{name}_mismatch")
        return torch.clamp_min(value * (1 + mismatch), DPINeuron.I0)

    @staticmethod
    def _time_constant(capacitance: float, Itau):
        """Return the DPI time constant ``(UT / KAPPA) * capacitance / Itau``."""
        return (DPINeuron.UT / DPINeuron.KAPPA) * capacitance / Itau

    def forward(self, X, state=None):
        """
        Advance the neuron population by one timestep of length ``dt``.

        Parameters
        ----------
        X: Input spikes for this timestep, expected shape ``(batch, n_in)``.
        state: Tuple ``(Imem, Iampa, Inmda, Igabaa, Igabab, refractory)`` from
            the previous step, or ``None`` to start from initialize(X).

        Returns
        -------
        spike: Output of ``surrogate_fn(Imem - Ith)``, shape ``(batch, n_out)``.
        state: The updated state tuple.
        """
        if state is None:
            state = self.initialize(X)
        (Imem, Iampa, Inmda, Igabaa, Igabab, refractory) = state
        Iahp = DPINeuron.I0

        # Apply mismatch: each bias is scaled by (1 + its mismatch buffer) and
        # clamped from below at the dark current.
        Idc = self._apply_mismatch(self.Idc, "Idc")
        Itau_mem = self._apply_mismatch(self.Itau_mem, "Itau_mem")
        Igain_mem = self._apply_mismatch(self.Igain_mem, "Igain_mem")
        Itau_ampa = self._apply_mismatch(self.Itau_ampa, "Itau_ampa")
        Igain_ampa = self._apply_mismatch(self.Igain_ampa, "Igain_ampa")
        Iw_ampa = self._apply_mismatch(self.Iw_ampa, "Iw_ampa")
        Inmda_thr = self._apply_mismatch(self.Inmda_thr, "Inmda_thr")
        Itau_nmda = self._apply_mismatch(self.Itau_nmda, "Itau_nmda")
        Igain_nmda = self._apply_mismatch(self.Igain_nmda, "Igain_nmda")
        Iw_nmda = self._apply_mismatch(self.Iw_nmda, "Iw_nmda")
        Itau_gabaa = self._apply_mismatch(self.Itau_gabaa, "Itau_gabaa")
        Igain_gabaa = self._apply_mismatch(self.Igain_gabaa, "Igain_gabaa")
        Iw_gabaa = self._apply_mismatch(self.Iw_gabaa, "Iw_gabaa")
        Itau_gabab = self._apply_mismatch(self.Itau_gabab, "Itau_gabab")
        Igain_gabab = self._apply_mismatch(self.Igain_gabab, "Igain_gabab")
        Iw_gabab = self._apply_mismatch(self.Iw_gabab, "Iw_gabab")

        # alpha = Igain_mem / Itau_mem with both currents mismatched.
        alpha = (
            self.alpha * (1 + self._Igain_mem_mismatch) / (1 + self._Itau_mem_mismatch)
        )
        # Ratio of beta = 1 + I0 / Itau_mem evaluated at the mismatched Itau_mem
        # to beta at the nominal Itau_mem. Detached, so gradients reach only self.beta.
        beta_mismatch = (
            DPINeuron.I0 + self.Itau_mem * (1 + self._Itau_mem_mismatch)
        ) / ((self._Itau_mem_mismatch + 1) * (DPINeuron.I0 + self.Itau_mem))
        beta = self.beta * beta_mismatch.detach()

        # Time constants
        tau_mem = self._time_constant(DPINeuron.CMEM, Itau_mem)  # Membrane
        tau_ampa = self._time_constant(DPINeuron.CAMPA, Itau_ampa)  # AMPA
        tau_nmda = self._time_constant(DPINeuron.CNMDA, Itau_nmda)  # NMDA
        tau_gabaa = self._time_constant(DPINeuron.CGABA_A, Itau_gabaa)  # GABAa
        tau_gabab = self._time_constant(DPINeuron.CGABA_B, Itau_gabab)  # GABAb

        # Synapse: inputs weighted by the rounded (integer) weight matrices;
        # the module-level ``round`` uses a straight-through gradient.
        numSynAmpa = torch.nn.functional.linear(X, round(self.W_ampa))
        numSynNmda = torch.nn.functional.linear(X, round(self.W_nmda))
        numSynGabaa = torch.nn.functional.linear(X, round(self.W_gabaa))
        numSynGabab = torch.nn.functional.linear(X, round(self.W_gabab))
        if self.training and self.train_ampa:
            for numSyn in (numSynAmpa, numSynNmda, numSynGabaa, numSynGabab):
                # Hooks can only be registered on tensors that require grad,
                # which is not the case e.g. under torch.no_grad().
                if numSyn.requires_grad:
                    # Amplify the gradient reaching the weight matrices by 1e10.
                    numSyn.register_hook(lambda grad: grad * 1e10)

        # Synapse derivatives (leak computed from the pre-input current), then
        # the instantaneous jump caused by this step's input spikes.
        dIampa = -Iampa / tau_ampa
        Iampa = Iampa + (Igain_ampa / Itau_ampa) * Iw_ampa * numSynAmpa
        dInmda = -Inmda / tau_nmda
        Inmda = Inmda + (Igain_nmda / Itau_nmda) * Iw_nmda * numSynNmda
        dIgabaa = -Igabaa / tau_gabaa
        Igabaa = Igabaa + (Igain_gabaa / Itau_gabaa) * Iw_gabaa * numSynGabaa
        dIgabab = -Igabab / tau_gabab
        Igabab = Igabab + (Igain_gabab / Itau_gabab) * Iw_gabab * numSynGabab

        # Soma
        # Input current: NMDA is gated by the membrane current through Inmda_thr,
        # and input is blocked while the neuron is refractory.
        Inmda_dp = Inmda / (1 + Inmda_thr / Imem)
        Iin = Idc + Iampa + Inmda_dp.detach() - Igabab
        Iin = Iin * (refractory <= 0)
        Iin = torch.clamp_min(Iin, DPINeuron.I0)
        # Positive feedback: sigmoidal gating around Ipfb_th with slope Ipfb_norm.
        Ifb = (
            DPINeuron.I0 ** (1 / (DPINeuron.KAPPA + 1))
            * Imem ** (DPINeuron.KAPPA / (DPINeuron.KAPPA + 1))
            / (1 + torch.exp(-self.Ipfb_norm * (Imem - self.Ipfb_th)))
        )
        f_imem = (Ifb / Itau_mem) * (Imem + Igain_mem)
        # Soma derivative (the GABAa shunt and feedback terms are detached).
        dImem = (
            alpha * (Iin - Itau_mem - Iahp - Igabaa.detach())
            - beta * Imem
            - ((Igabaa / Itau_mem) * Imem).detach()
            + f_imem.detach()
        ) / (tau_mem * (1 + Igain_mem / Imem))

        # Forward-Euler update, keeping every current at or above I0.
        Imem = Imem + dImem * self.dt
        Imem = torch.clamp_min(Imem, DPINeuron.I0)
        Iampa = Iampa + dIampa * self.dt
        Iampa = torch.clamp_min(Iampa, DPINeuron.I0)
        Inmda = Inmda + dInmda * self.dt
        Inmda = torch.clamp_min(Inmda, DPINeuron.I0)
        Igabaa = Igabaa + dIgabaa * self.dt
        Igabaa = torch.clamp_min(Igabaa, DPINeuron.I0)
        Igabab = Igabab + dIgabab * self.dt
        Igabab = torch.clamp_min(Igabab, DPINeuron.I0)

        # Spike: reset the membrane to I0 and restart the refractory counter.
        spike = self.surrogate_fn(Imem - self.Ith)
        Imem = (1.0 - spike) * Imem + spike * DPINeuron.I0
        refractory = refractory - self.dt
        refractory = torch.clamp_min(refractory, 0.0)
        refractory = (1.0 - spike) * refractory + spike * self.refP

        # Save state
        state = (Imem, Iampa, Inmda, Igabaa, Igabab, refractory)
        return spike, state
class ADM(nn.Module):
    """Adaptive Delta Modulation (ADM) module.

    Converts an analog signal into UP and DOWN spikes using the Adaptive
    Delta Modulation scheme.
    """

    def __init__(
        self,
        N: int,
        threshold_up: float,
        threshold_down: float,
        refractory: int,
        surrogate_fn: Callable = fast_sigmoid,
    ):
        """
        Parameters
        ----------
        N: int
            Number of input synapses
        threshold_up: float
            Threshold for UP spike
        threshold_down: float
            Threshold for DOWN spike
        refractory: int
            Refractory period, in forward() calls (the counter drops by 1 per call)
        surrogate_fn: Callable
            Surrogate gradient function for spiking
        """
        super(ADM, self).__init__()
        self.activation_fn = surrogate_fn
        self.refractory = nn.Parameter(
            torch.tensor(refractory).float(), requires_grad=True
        )
        self.N = N
        self.thr_up = nn.Parameter(torch.tensor(threshold_up), requires_grad=True)
        self.thr_down = nn.Parameter(torch.tensor(threshold_down), requires_grad=True)
        self.reset()

    def reset(self):
        """Clear the internal state; the next forward() call re-initializes it."""
        self.refrac = None
        self.DC_Voltage = None

    def reconstruct(self, spikes, initial_value=0):
        """Reconstruct an analog signal from the UP and DOWN spikes produced by the ADM module.

        Every time the algorithm receives an UP/DOWN spike, the reconstructed
        signal is incremented/decremented by the UP/DOWN threshold amount.

        Parameters
        ----------
        spikes: Input spikes from which the signal is reconstructed, indexed as
            ``(batch, time, channel)``. The first half of the channels holds UP
            spikes and the second half holds DOWN spikes.
        initial_value: Initial reconstructed signal value.

        Returns
        -------
        Tensor of shape ``(batch, time, channels // 2)`` on ``spikes.device``.
        Spikes at time index 0 are not used.
        """
        half = spikes.shape[-1] // 2
        # Allocate on the spikes' device so the arithmetic below does not mix devices.
        reconstructed = torch.zeros(
            spikes.shape[0], spikes.shape[1], half, device=spikes.device
        )
        reconstructed[:, 0, :] = initial_value
        for t in range(1, spikes.shape[1]):
            spikes_p = spikes[:, t, :half]
            spikes_n = spikes[:, t, half:]
            reconstructed[:, t] = (
                reconstructed[:, t - 1]
                + self.thr_up * spikes_p
                - self.thr_down * spikes_n
            )
        return reconstructed

    def forward(self, input_signal):
        """Encode one timestep of ``input_signal`` into UP/DOWN spikes.

        On the first call after construction or reset(), no spikes are emitted
        and the input is stored as the reference level ``DC_Voltage``. On later
        calls, an UP (DOWN) spike is produced when the input rises above (falls
        below) the reference by ``thr_up`` (``thr_down``), provided the channel
        is not refractory.

        Returns
        -------
        (output, output_p, output_n), where ``output`` is ``output_p`` and
        ``output_n`` concatenated along dim 1 (zeros of shape
        ``(batch, 2 * N)`` on the first call).
        """
        if self.DC_Voltage is None:
            output = torch.zeros(
                input_signal.shape[0], self.N * 2, device=input_signal.device
            )
            output_p = torch.zeros_like(input_signal)
            output_n = torch.zeros_like(input_signal)
            self.refrac = torch.zeros_like(input_signal)
            self.DC_Voltage = input_signal
        else:
            # UP spike: input exceeds the reference by thr_up on a non-refractory channel.
            output_p = (
                self.activation_fn(
                    (input_signal - (self.DC_Voltage.detach() + self.thr_up))
                )
                * (self.refrac == 0).float()
            )
            # A spike loads the refractory counter.
            self.refrac = output_p * self.refractory + (1 - output_p) * self.refrac
            # DOWN spike: input is below the reference by thr_down on a non-refractory channel.
            output_n = (
                self.activation_fn(
                    ((self.DC_Voltage.detach() - self.thr_down) - input_signal)
                )
                * (self.refrac == 0).float()
            )
            self.refrac = output_n * self.refractory + (1 - output_n) * self.refrac
            # The reference follows the input on channels whose counter equals 1.
            # NOTE(review): the counter only hits exactly 1 for integer-valued
            # refractory >= 1; with refractory == 0 (or non-integer values)
            # DC_Voltage is never updated after the first call -- confirm intended.
            change_v = (self.refrac == 1).float()
            self.DC_Voltage = change_v * input_signal + (1 - change_v) * self.DC_Voltage
            output = torch.cat([output_p, output_n], dim=1)
        # Count down the refractory period by one call.
        self.refrac = torch.nn.functional.relu(self.refrac - 1)
        return output, output_p, output_n