Source code for arcana.surrogate
import torch
[docs]
class FastSigmoid(torch.autograd.Function):
    r"""Fast-sigmoid surrogate gradient.

    The forward pass is the Heaviside step function (1 where ``V >= 0``,
    0 elsewhere). The backward pass replaces the step's derivative with
    the derivative of a fast sigmoid:

    .. math::
        \frac{\partial S}{\partial V} = \frac{1}{(\lambda \left|V\right| + 1.0)^2}

    where :math:`\lambda` is the class attribute ``scale`` (default 10).
    """

    scale = 10  # lambda: larger values give a sharper, narrower surrogate peak

    @staticmethod
    def pseudo_derivative(v):
        """Evaluate the fast-sigmoid surrogate gradient elementwise.

        Args:
            v (torch.Tensor): Neuron voltage to which the threshold is applied.

        Returns:
            torch.Tensor: ``1 / (scale * |v| + 1) ** 2``.
        """
        denominator = FastSigmoid.scale * torch.abs(v) + 1.0
        return 1.0 / denominator ** 2

    @staticmethod
    def forward(ctx, V):
        """Heaviside step of ``V``, returned in ``V``'s dtype."""
        ctx.save_for_backward(V)
        fired = V >= 0
        return fired.type(V.dtype)

    @staticmethod
    def backward(ctx, dy):
        """Chain the upstream gradient through the fast-sigmoid surrogate."""
        V, = ctx.saved_tensors
        return dy * FastSigmoid.pseudo_derivative(V)


fast_sigmoid = FastSigmoid.apply
[docs]
class Step(torch.autograd.Function):
    """Step-function surrogate gradient.

    The forward pass is the Heaviside step function (1 where ``V >= 0``,
    0 elsewhere). The backward pass reuses that same step as the surrogate
    derivative. Upstream gradients are therefore passed through unchanged
    where ``V >= 0`` and zeroed where ``V < 0``.
    """

    @staticmethod
    def pseudo_derivative(V):
        """Compute the step-function surrogate gradient.

        Args:
            V (torch.Tensor): Neuron voltage to which the threshold is applied.

        Returns:
            torch.Tensor: The surrogate step gradient of ``V``: 1 where
            ``V >= 0``, 0 elsewhere, in ``V``'s dtype.
        """
        return (V >= 0).type(V.dtype)

    @staticmethod
    def forward(ctx, V):
        """Heaviside step of ``V``, returned in ``V``'s dtype."""
        ctx.save_for_backward(V)
        return (V >= 0).type(V.dtype)

    @staticmethod
    def backward(ctx, dy):
        """Chain the upstream gradient through the step surrogate."""
        (V,) = ctx.saved_tensors
        dE_dz = dy
        dz_dv_scaled = Step.pseudo_derivative(V)
        return dE_dz * dz_dv_scaled


step = Step.apply
[docs]
class Triangular(torch.autograd.Function):
    r"""Triangular surrogate gradient.

    The forward pass is the Heaviside step function (1 where ``V >= 0``,
    0 elsewhere). The backward pass replaces the step's derivative with a
    triangular bump centred at ``V == 0``:

    .. math::
        \frac{\partial S}{\partial V} = \lambda \max(1 - \left|V\right|, 0)

    where :math:`\lambda` is the class attribute ``scale`` (default 0.3).
    The surrogate is zero for ``|V| >= 1``.
    """

    scale = 0.3  # lambda: peak height of the triangular surrogate (reached at V == 0)

    @staticmethod
    def pseudo_derivative(v):
        """Compute the triangular surrogate gradient.

        Args:
            v (torch.Tensor): Neuron voltage to which the threshold is applied.

        Returns:
            torch.Tensor: ``scale * max(1 - |v|, 0)``, elementwise.
        """
        # clamp(min=0) keeps v's dtype and device. It avoids building a new
        # CPU int64 tensor (torch.tensor(0)) on every backward call and
        # relying on mixed-device/dtype promotion inside torch.maximum.
        return torch.clamp(1 - torch.abs(v), min=0) * Triangular.scale

    @staticmethod
    def forward(ctx, V):
        """Heaviside step of ``V``, returned in ``V``'s dtype."""
        ctx.save_for_backward(V)
        return (V >= 0).type(V.dtype)

    @staticmethod
    def backward(ctx, dy):
        """Chain the upstream gradient through the triangular surrogate."""
        (V,) = ctx.saved_tensors
        dE_dz = dy
        dz_dv_scaled = Triangular.pseudo_derivative(V)
        return dE_dz * dz_dv_scaled


triangular = Triangular.apply
[docs]
class STE(torch.autograd.Function):
    """Straight-through estimator.

    The forward pass is the Heaviside step function (1 where ``V >= 0``,
    0 elsewhere). The backward pass uses a surrogate derivative of 1
    everywhere, so the upstream gradient passes straight through.
    """

    @staticmethod
    def pseudo_derivative(v):
        """Return the STE surrogate gradient.

        Args:
            v (torch.Tensor): Neuron voltage to which the threshold is applied.

        Returns:
            torch.Tensor: A tensor of ones shaped and typed like ``v``.
        """
        return torch.ones_like(v)

    @staticmethod
    def forward(ctx, V):
        """Heaviside step of ``V``, returned in ``V``'s dtype."""
        ctx.save_for_backward(V)
        above_threshold = V >= 0
        return above_threshold.type(V.dtype)

    @staticmethod
    def backward(ctx, dy):
        """Multiply the upstream gradient by the all-ones STE surrogate."""
        V, = ctx.saved_tensors
        surrogate = STE.pseudo_derivative(V)
        return dy * surrogate


ste = STE.apply