from ..variables import Parameter
from .. import variables
from ..base import Layer
import torch
[docs]class Poisson(Layer):
"""Poisson spiking model.
Input has to be a firing rate between 0.0 and 1.0.
.. versionadded:: 0.6.4
"""
def __init__(self,**kwargs):
super(Poisson, self).__init__()
self.dims = 5
def forward(self, I_in):
return variables.Variable(torch.rand(*I_in.size()).float()) <= I_in.float()
_izhikevich_parameters = {
'Tonic spiking':
[0.02,0.2,-65.0,6.0,14.0],
'Phasic spiking':
[0.02,0.25,-65.0,6.0,0.5],
'Tonic bursting':
[0.02,0.2,-50.0,2.0,15.0],
'Phasic bursting':
[0.02,0.25,-55.0,0.05,0.6],
'Mixed mode':
[0.02,0.2,-55.0,4.0,10.0],
'Sf. adaptation':
[0.01,0.2,-65.0,8.0,30.0],
'Class 1':
[0.02,-0.1,-55.0,6.0,0.0],
'Class 2':
[0.2,0.26,-65.0,0.0,0.0],
'Spike latency':
[0.02,0.2,-65.0,6.0,7.0],
'Subthreshold osc':
[0.05,0.26,-60.0,0.0,0.0],
'Resonator':
[0.1,0.26,-60.0,-1.0,0.0],
'Integrator':
[0.02,-0.1,-55.0,6.0,0.0],
'Rebound spike':
[0.03,0.25,-60.0,4.0,0.0],
'Rebound burst':
[0.03,0.25,-52.0,0.0,0.0],
'Threshold var':
[0.03,0.25,-60.0,4.0,0.0],
'Bistability':
[1.0,1.5,-60.0,0.0,-65.0],
'DAP':
[1.0,0.2,-60.0,-21.0,0.0],
'Accomodation':
[0.02,1.0,-55.0,4.0,0.0],
'Inh-ind. spiking':
[-0.02,-1.0,-60.0,8.0,80.0],
'Inh-ind. bursting':
[-0.026,-1.0,-45.0,0.0,80.0]
}
[docs]class Izhikevich(Layer):
"""Izhikevich Spiking Model with uniform parameters
The Simple Model of Spiking Neurons after Eugene Izhikevich
offers a wide range of neural dynamics with very few parameters.
See: https://www.izhikevich.org/publications/spikes.htm
Each pixel has two state variables: `v` and `u`.
`v` corresponds roughly to the membrane potential of a neuron
and `u` to a slow acting ion concentration. Both
variables influence each other dynamically:
$$\\\\dot{v} = 0.04 \\\\cdot v^2 + 5 \\\\cdot v + 140 - u + I$$
$$\\\\dot{u} = a \\\\cdot (b \\\\cdot v - u)$$
If `v` crosses a threshold, it will be reset to a value `c` and
`u` will be increased by another value `d`.
The parameters of the model are:
- `a`: relative speed between the evolution of `v` and `u`
- `b`: amount of influence of `v` over `u`
- `c`: the reset value for `v` if it crosses threshold
- `d`: value to add to `u` if `v` crosses threshold
Parameters
----------
output_only_spikes (bool)
whether only spikes should be returned (binary),
or spikes, membrane potential and slow potential
in one channel of the output each.
Examples
--------
See also
--------
"""
def __init__(self,output_only_spikes=True,**kwargs):
super(Izhikevich, self).__init__()
self.dims = 5
# parameters
self.a = Parameter(0.02)
self.b = Parameter(0.2)
self.c = Parameter(-65.0, doc='reset value')
self.d = Parameter(6.0, doc='increase of u when v is above threshold')
self.threshold = Parameter(30.0)
self.noise_strength = Parameter(0.001)
self.register_state('v',None)
self.register_state('u',None)
self.iters = 2
self.output_only_spikes = output_only_spikes
[docs] def load_parameters_by_name(self, name=None):
"""Allows to load parameters for a range of behaviors.
For a list of possible options, run the method without
parameter or look at directly at the dictionary
`convis.filters.spiking._izhikevich_parameters`.
The dictionary has values for a,b,c,d and the recommended
input.
"""
if name in _izhikevich_parameters.keys():
p = _izhikevich_parameters[name]
self.a.set(p[0])
self.b.set(p[1])
self.c.set(p[2])
self.d.set(p[3])
print('Recommendet input: '+str(p[4]))
else:
print('Could not find key '+str(name))
print('Posibilities are:\n'+'\n - '.join(_izhikevich_parameters.keys()))
def init_states(self,input_shape):
self.zeros = variables.zeros((input_shape[3],input_shape[4]))
self.v = self.c*variables.ones((input_shape[3],input_shape[4])) # membrane potential
self.u = self.b*self.c*variables.ones((input_shape[3],input_shape[4])) # slow potential
def forward(self, I_in):
if not hasattr(self,'v') or self.v is None:
self.init_states(I_in.data.shape)
all_spikes = []
for t, I in enumerate(I_in.squeeze(0).squeeze(0)):
noise = variables.randn(I.data.shape).cpu()
noise_w = variables.randn(I.data.shape).cpu()
dt = 1.0/float(self.iters)
for i in range(self.iters):
dv = dt*(0.04*self.v*self.v + 5.0* self.v + 140 - self.u + I)
du = dt*self.a*(self.b*self.v - self.u)
self.v = self.v + dv
self.u = self.u + du
spikes = self.v >= self.threshold
self.v.masked_scatter_(spikes, self.zeros+self.c)
self.u.masked_scatter_(spikes, self.u+self.d)
spikes = spikes.float()
if self.output_only_spikes:
all_spikes.append(spikes[None,None,:,:])
else:
all_spikes.append(torch.cat([spikes[None,None,:,:],
self.v[None,None,:,:],
self.u[None,None,:,:]],dim=0))
return torch.cat(all_spikes,dim=1)[None,:,:,:,:]
[docs]class RefractoryLeakyIntegrateAndFireNeuron(Layer):
"""LIF model with refractory period.
Identical to `convis.filter.retina.GanglionSpiking`.
The ganglion cells recieve the gain controlled input and produce spikes.
When the cell is not refractory, :math:`V` moves as:
$$ \\\\dfrac{ dV_n }{dt} = I_{Gang}(x_n,y_n,t) - g^L V_n(t) + \eta_v(t)$$
Otherwise it is set to 0.
Attributes
----------
refr_mu : Parameter
The mean of the distribution of random refractory times (in seconds).
refr_sigma : Parameter
The standard deviation of the refractory time that is randomly drawn around `refr_mu`
noise_sigma : Parameter
Amount of noise added to the membrane potential.
g_L : Parameter
Leak current (in Hz or dimensionless firing rate).
See Also
--------
convis.retina.Retina
GanglionInput
LeakyIntegrateAndFireNeuron
"""
def __init__(self,**kwargs):
super(RefractoryLeakyIntegrateAndFireNeuron, self).__init__()
from .. import default_resolution
self.dims = 5
# parameters
self.refr_mu = Parameter(0.003,
retina_config_key='refr-mean__sec',
doc='The mean of the distribution of random refractory times (in seconds).')
self.refr_sigma = Parameter(0.001,
retina_config_key='refr-stdev__sec',
doc='The standard deviation of the refractory time that is randomly drawn around `refr_mu`')
self.noise_sigma = Parameter(0.1,
retina_config_key='sigma-V',
doc='Amount of noise added to the membrane potential.')
self.g_L = Parameter(50.0,
retina_config_key='g-leak__Hz',
doc='Leak current (in Hz or dimensionless firing rate).')
self.tau = Parameter(1.0/default_resolution.steps_per_second,
retina_config_key='--should be inherited',
doc = 'Length of timesteps (ie. the steps_to_seconds(1.0) of the model.')
self.register_state('V',None)
self.register_state('zeros',None)
self.register_state('refr',None)
self.register_state('noise_prev',None)
def init_states(self,input_shape):
self.zeros = variables.zeros((input_shape[3],input_shape[4]))
self.V = 0.5+0.2*variables.rand((input_shape[3],input_shape[4])) # membrane potential
if self._use_cuda:
self.refr = 1000.0*(self.refr_mu + self.refr_sigma *
variables.randn((input_shape[3],input_shape[4])).cuda())
else:
self.refr = 1000.0*(self.refr_mu + self.refr_sigma *
variables.randn((input_shape[3],input_shape[4])).cpu())
self.noise_prev = variables.zeros((input_shape[3],input_shape[4]))
def forward(self, I_gang):
g_infini = 50.0 # apparently?
if not hasattr(self,'V') or self.V is None:
self.init_states(I_gang.data.shape)
if self._use_cuda:
self.V = self.V.cuda()
self.refr = self.refr.cuda()
self.zeros = self.zeros.cuda()
self.noise_prev = self.noise_prev.cuda()
else:
self.V = self.V.cpu()
self.refr = self.refr.cpu()
self.zeros = self.zeros.cpu()
self.noise_prev = self.noise_prev.cpu()
all_spikes = []
for t, I in enumerate(I_gang.squeeze(0).squeeze(0)):
#print I.data.shape, self.V.data.shape,torch.randn(I.data.shape).shape
if self._use_cuda:
noise = variables.randn(I.data.shape).cuda()
else:
noise = variables.randn(I.data.shape).cpu()
V = self.V + (I/self.tau - self.g_L * self.V + self.noise_sigma*noise*torch.sqrt(self.g_L/self.tau))*self.tau
# canonical form:
#
# V = V + (E_L - V + R*I)*dt/tau
# + self.noise_sigma*noise*torch.sqrt(2.0*dt/tau)
# with dt = self.tau
# tau = 1/self.g_L
# R = tau
# E_L = 0
#
if self._use_cuda:
refr_noise = 1000.0*(self.refr_mu + self.refr_sigma *
variables.randn(I.data.shape).cuda())
else:
refr_noise = 1000.0*(self.refr_mu + self.refr_sigma *
variables.randn(I.data.shape).cpu())
spikes = V > 1.0
self.refr.masked_scatter_(spikes, refr_noise)
self.refr.masked_scatter_(self.refr < 0.0, self.zeros)
self.refr = self.refr - 1.0
V.masked_scatter_(self.refr >= 0.5, self.zeros)
V.masked_scatter_(spikes, self.zeros)
self.V = V
all_spikes.append(spikes[None,:,:])
return torch.cat(all_spikes,dim=0)[None,None,:,:,:]
[docs]class LeakyIntegrateAndFireNeuron(Layer):
"""LIF model.
$$ \\\\dfrac{ dV_n }{dt} = I_{Gang}(x_n,y_n,t) - g^L V_n(t) + \eta_v(t)$$
Attributes
----------
noise_sigma : Parameter
Amount of noise added to the membrane potential.
g_L : Parameter
Leak current (in Hz or dimensionless firing rate).
See Also
--------
RefractoryLeakyIntegrateAndFireNeuron
"""
def __init__(self,**kwargs):
super(LeakyIntegrateAndFireNeuron, self).__init__()
from .. import default_resolution
self.dims = 5
# parameters
self.noise_sigma = Parameter(0.1,
retina_config_key='sigma-V',
doc='Amount of noise added to the membrane potential.')
self.g_L = Parameter(50.0,
retina_config_key='g-leak__Hz',
doc='Leak current (in Hz or dimensionless firing rate).')
self.tau = Parameter(1.0/default_resolution.steps_per_second,
retina_config_key='--should be inherited',
doc = 'Length of timesteps (ie. the steps_to_seconds(1.0) of the model.')
self.register_state('V',None)
self.register_state('zeros',None)
self.register_state('noise_prev',None)
def init_states(self,input_shape):
self.zeros = variables.zeros((input_shape[3],input_shape[4]))
self.V = 0.5+0.2*variables.rand((input_shape[3],input_shape[4])) # membrane potential
self.noise_prev = variables.zeros((input_shape[3],input_shape[4]))
def forward(self, I_gang):
g_infini = 50.0 # apparently?
if not hasattr(self,'V') or self.V is None:
self.init_states(I_gang.data.shape)
if self._use_cuda:
self.V = self.V.cuda()
self.zeros = self.zeros.cuda()
self.noise_prev = self.noise_prev.cuda()
else:
self.V = self.V.cpu()
self.zeros = self.zeros.cpu()
self.noise_prev = self.noise_prev.cpu()
all_spikes = []
for t, I in enumerate(I_gang.squeeze(0).squeeze(0)):
if self._use_cuda:
noise = variables.randn(I.data.shape).cuda()
else:
noise = variables.randn(I.data.shape).cpu()
V = self.V + (I/self.tau - self.g_L * self.V + self.noise_sigma*noise*torch.sqrt(self.g_L/self.tau))*self.tau
spikes = V > 1.0
V.masked_scatter_(spikes, self.zeros)
self.V = V
all_spikes.append(spikes[None,:,:])
return torch.cat(all_spikes,dim=0)[None,None,:,:,:]
[docs]class FitzHughNagumo(Layer):
"""Two state neural model.
$$\\\\dot{v} = v - \\\\frac{1}{3} v^3 - w + I$$
$$\\\\dot{w} = \\\\tau \\\\cdot (v - a - b \\\\cdot w) $$
See also:
- `Wikipedia on FitzHugh-Nagumo models <https://en.wikipedia.org/wiki/FitzHugh%E2%80%93Nagumo_model>`_
"""
def __init__(self,**kwargs):
super(FitzHughNagumo, self).__init__()
self.dims = 5
self.a = Parameter(0.7)
self.b = Parameter(0.8)
self.tau = Parameter(0.08)
self.noise_strength = Parameter(0.001)
self.register_state('v',None)
self.register_state('w',None)
self.iters = 10
def init_states(self,input_shape):
self.v = 0.0*variables.randn((input_shape[3],input_shape[4])) # membrane potential
self.w = 0.0*variables.randn((input_shape[3],input_shape[4])) # slow potential
def forward(self, I_in):
if not hasattr(self,'v') or self.v is None:
self.init_states(I_in.data.shape)
all_spikes = []
for t, I in enumerate(I_in.squeeze(0).squeeze(0)):
noise = variables.randn(I.data.shape).cpu()
noise_w = variables.randn(I.data.shape).cpu()
dt = 1.0/float(self.iters)
for i in range(self.iters):
dv = dt*(I + self.v - (self.v**3.0)/3.0 - self.w)
dw = dt*(self.v - self.a -self.b*self.w)*self.tau
self.v = self.v + dv + self.noise_strength * dt*noise
self.w = self.w + dw + self.noise_strength * dt*noise_w
all_spikes.append(self.v[None,:,:])
return torch.cat(all_spikes,dim=0)[None,None,:,:,:]
[docs]class HogkinHuxley(Layer):
"""Neuron model of the giant squid axon.
This model contains four state variables:
the membrane potential `v` and three slow acting currents `n`, `m` and `h`.
See also:
- `Wikipedia in Hodgkin-Huxley models <https://en.wikipedia.org/wiki/Hodgkin%E2%80%93Huxley_model>`_
- `<http://neuronaldynamics.epfl.ch/online/Ch2.S2.html>`_
"""
def __init__(self,**kwargs):
super(HogkinHuxley, self).__init__()
self.dims = 5
# parameters
self.gK = Parameter(36.0, doc='maximum conductance of Potassium channels')
self.gNa = Parameter(120.0, doc='maximum conductance of Sodium channels')
self.gL = Parameter(0.3, doc='leak current')
self.Cm = Parameter(1.0, doc='membrane capacitance')
self.VK = Parameter(-12.0,doc='potential of Potassium')
self.VNa = Parameter(115.0,doc='potential of Sodium')
self.Vl = Parameter(10.613,doc='potential of leak currents')
self.noise_strength = Parameter(0.1)
self.alpha_n = lambda v: (0.01 * (10.0 - v)) / (torch.exp(1.0 - (0.1 * v)) - 1.0)
self.beta_n = lambda v : 0.125 * torch.exp(-v / 80.0)
self.alpha_m = lambda v : (0.1 * (25.0 - v)) / (torch.exp(2.5 - (0.1 * v)) - 1.0)
self.beta_m = lambda v : 4.0 * torch.exp(-v / 18.0)
self.alpha_h = lambda v : 0.07 * torch.exp(-v / 20.0)
self.beta_h = lambda v : 1.0 / (torch.exp(3.0 - (0.1 * v)) + 1.0)
self.n_inf = lambda v : alpha_n(v) / (alpha_n(v) + beta_n(v))
self.m_inf = lambda v : alpha_m(v) / (alpha_m(v) + beta_m(v))
self.h_inf = lambda v : alpha_h(v) / (alpha_h(v) + beta_h(v))
self.register_state('v',None)
self.register_state('v_n',None)
self.register_state('v_m',None)
self.register_state('v_h',None)
self.iters = 20
def init_states(self,input_shape):
self.v = 0.0*variables.randn((input_shape[3],input_shape[4])) # membrane potential
self.v_n = 0.0*variables.randn((input_shape[3],input_shape[4])) # slow potential
self.v_m = 0.0*variables.randn((input_shape[3],input_shape[4])) # slow potential
self.v_h = 0.0*variables.randn((input_shape[3],input_shape[4])) # slow potential
def forward(self, I_in):
if not hasattr(self,'v') or self.v is None:
self.init_states(I_in.data.shape)
all_spikes = []
for t, I in enumerate(I_in.squeeze(0).squeeze(0)):
noise = variables.randn(I.data.shape).cpu()
noise_w = variables.randn(I.data.shape).cpu()
dt = 1.0/float(self.iters)
for i in range(self.iters):
dv = ((I + self.noise_strength*noise / self.Cm)
- ((self.gK / self.Cm) * self.v_n**4.0 * (self.v - self.VK))
- ((self.gNa / self.Cm) * self.v_m**3.0 * self.v_h * (self.v - self.VNa))
- (self.gL / self.Cm * (self.v - self.Vl))
)
dn = (self.alpha_n(self.v) * (1.0 - self.v_n)) - (self.beta_n(self.v) * self.v_n)
dm = (self.alpha_m(self.v) * (1.0 - self.v_m)) - (self.beta_m(self.v) * self.v_m)
dh = (self.alpha_h(self.v) * (1.0 - self.v_h)) - (self.beta_h(self.v) * self.v_h)
self.v = self.v + dt*dv
self.v_n = self.v_n + dt*dn
self.v_m = self.v_m + dt*dm
self.v_h = self.v_h + dt*dh
all_spikes.append(self.v[None,:,:])
return torch.cat(all_spikes,dim=0)[None,None,:,:,:]
[docs]class IntegrativeMotionSensor(Layer):
"""A spiking integrator that will fire and readjust its threshold in 'linear' or 'log'(arithmic) mode.
The output of the Layer has two channels: On and Off spikes.
On spikes are fired if the values surpass a positive threshold.
Off spikes are fired if the values fall below a negative threshold (`-threshold` for 'linear and `1/threshold` for 'log').
.. versionadded:: 0.6.4
Parameters
----------
spiking_mode : str
'linear' or 'log'
threshold : float
The dynamic threshold in/decreases.
In 'linear' mode, the threshold is increased by adding `threshold`.
In 'log' mode, the threshold is increased by multiplying with `threshold`.
"""
def __init__(self,spiking_mode='linear',threshold=1.0):
self.dim = 5
super(IntegrativeMotionSensor,self).__init__()
self.register_state('last_frame',None)
self.threshold = threshold
self.spiking_mode = spiking_mode
def forward(self,inp):
if self.last_frame is not None:
frame = self.last_frame
else:
if self.spiking_mode is 'linear':
frame = torch.zeros((inp.size()[0],inp.size()[1],1,inp.size()[3],inp.size()[4]))
elif self.spiking_mode is 'log':
frame = 0.5*torch.ones((inp.size()[0],inp.size()[1],1,inp.size()[3],inp.size()[4]))
else:
raise Exception('`spiking_mode` %s not recognized.'%(self.spiking_mode))
spikes_on = []
spikes_off = []
for i in range(inp.size()[2]):
if self.spiking_mode is 'linear':
frame += inp[:,:,i,:,:]
spikes_on.append(frame > self.threshold)
spikes_off.append(frame < -self.threshold)
frame.masked_scatter_(spikes_on[-1], frame-self.threshold)
frame.masked_scatter_(spikes_off[-1], frame+self.threshold)
elif self.spiking_mode is 'log':
spikes_on.append(inp[:,:,i,:,:]/frame > 1.0*self.threshold)
spikes_off.append(inp[:,:,i,:,:]/frame < 1.0/self.threshold)
if self.threshold <= 1.0:
raise Exception("`threshold` has to be larger than 1 for logarithmic `spiking_mode`!")
frame.masked_scatter_(spikes_on[-1], frame*self.threshold)
frame.masked_scatter_(spikes_off[-1], frame/self.threshold)
else:
raise Exception('`spiking_mode` %s not recognized.'%(self.spiking_mode))
self.last_frame = frame
spikes_on = torch.cat(spikes_on,dim=2)
spikes_off = torch.cat(spikes_off,dim=2)
return (torch.cat([spikes_on,spikes_off],dim=1)).float()