# Source code for convis.retina

# -*- coding: utf-8 -*-
"""
This module provides the retina model.
"""
from __future__ import print_function
from .base import Layer, Model,Output
from .retina_virtualretina import RetinaConfiguration
from .filters import retina as rf

class Retina(Layer):
    """
    A retinal ganglion cell model comparable to VirtualRetina [Wohrer2009]_.

    .. [Wohrer2009] Wohrer, A., & Kornprobst, P. (2009).
        Virtual Retina: a biological retina model and simulator, with
        contrast gain control. Journal of Computational Neuroscience,
        26(2), 219-49. http://doi.org/10.1007/s10827-008-0108-4

    Parameters
    ----------
    opl : bool or Layer
        If True, a new `convis.filters.retina.OPL` is created. Any other
        value is stored as-is; a falsy value removes the OPL stage from
        the processing chain.
    bipolar : bool or Layer
        If True, a new `convis.filters.retina.Bipolar` is created. Any
        other value is stored as-is; if falsy, the input is copied
        unchanged into both ganglion pathways (buffers 'I1' and 'I2').
    gang : bool
        whether the two `GanglionInput` stages are part of the chain.
    spikes : bool
        whether the two `GanglionSpiking` stages are part of the chain
        (only takes effect when `gang` is True).

    Attributes
    ----------
    opl : Layer (convis.filters.retina.OPL)
    bipolar : Layer (convis.filters.retina.Bipolar)
    gang_0_input : Layer (convis.filters.retina.GanglionInput)
    gang_0_spikes : Layer (convis.filters.retina.GanglionSpiking)
    gang_1_input : Layer (convis.filters.retina.GanglionInput)
    gang_1_spikes : Layer (convis.filters.retina.GanglionSpiking)
    commands : list of tuples
        the processing chain executed by `forward`. Each entry is
        (list of output buffer names, layer or the string 'copy',
        list of input buffer names).
    _timing : list of tuples
        timing information of the last run (last chunk).
        Each entry is a tuple of (function that was executed,
        number of seconds it took to execute).
    keep_timing_info : bool
        whether to store all timing information in a list
    timing_info : list
        stores timing information of all runs if `keep_timing_info` is True.

    Examples
    --------

    .. plot::
        :include-source:

        import convis
        import numpy as np
        from matplotlib import pylab as plt

        retina = convis.retina.Retina()
        inp = convis.samples.moving_grating(2000)
        inp = np.concatenate([inp[:1000],2.0*inp[1000:1500],inp[1500:2000]],axis=0)
        o_init = retina.run(inp[:500],dt=200)
        o = retina.run(inp[500:],dt=200)
        convis.plot_5d_time(o[0],mean=(3,4))
        # plots the mean activity over time

        plt.figure()
        retina = convis.retina.Retina(opl=True,bipolar=False,gang=True,spikes=False)
        o_init = retina.run(inp[:500],dt=200)
        o = retina.run(inp[500:],dt=200)
        convis.plot_5d_time(o[0],mean=(3,4))
        # plots the mean activity over time
        convis.plot_5d_time(o[0],alpha=0.1)
        # plots a line for each pixel

    See Also
    --------
    convis.base.Layer : The Layer base class, providing chunking and optimization
    convis.filters.retina.OPL : The outer plexiform layer performs luminance to contrast conversion
    convis.filters.retina.Bipolar : provides contrast gain control
    convis.filters.retina.GanglionInput : provides a static non-linearity and a last spatial integration
    convis.filters.retina.GanglionSpiking : creates spikes from an input current
    """
    def __init__(self, opl=True, bipolar=True, gang=True, spikes=True):
        # Set before the base initializer, as in the original ordering.
        self.keep_timing_info = False
        self.timing_info = []
        super(Retina, self).__init__()
        self.opl = rf.OPL() if opl is True else opl
        self.bipolar = rf.Bipolar() if bipolar is True else bipolar
        self.gang_0_input = rf.GanglionInput()
        self.gang_0_spikes = rf.GanglionSpiking()
        self.gang_1_input = rf.GanglionInput()
        self.gang_1_spikes = rf.GanglionSpiking()
        # Each command is (output buffers, layer or 'copy', input buffers):
        # Ix = f(Iy, Iz, ...)
        self.commands = []
        if opl:
            self.commands.append((['I1'], self.opl, ['I1']))
        if bipolar:
            self.commands.append((['I1', 'I2'], self.bipolar, ['I1']))
        else:
            # without a bipolar stage both pathways start from the same signal
            self.commands.append((['I1', 'I2'], 'copy', ['I1']))
        if gang:
            self.commands.append((['I1'], self.gang_0_input, ['I1']))
            if spikes:
                self.commands.append((['I1'], self.gang_0_spikes, ['I1']))
        if gang:
            self.commands.append((['I2'], self.gang_1_input, ['I2']))
            if spikes:
                self.commands.append((['I2'], self.gang_1_spikes, ['I2']))
        conf = RetinaConfiguration()
        self.parse_config(conf)

    def cuda(self, *args, **kwargs):
        """Move the retina and its processing stages to the GPU.

        All arguments are forwarded to every `cuda` call. Stages that were
        disabled (e.g. ``bipolar=False`` stores ``False``) or that lack the
        expected attributes are skipped instead of raising AttributeError.

        Returns
        -------
        self : Retina
            so that ``retina = Retina().cuda()`` works like
            ``torch.nn.Module.cuda``.
        """
        # for now, the modules are not collected!
        super(Retina, self).cuda(*args, **kwargs)
        opl_filter = getattr(self.opl, 'opl_filter', None)
        if opl_filter is not None:
            opl_filter.cuda(*args, **kwargs)
            for part_name in ('center_E', 'surround_E', 'center_G',
                              'surround_G', 'center_undershoot'):
                part = getattr(opl_filter, part_name, None)
                if part is not None:
                    part.cuda(*args, **kwargs)
        for layer in (self.bipolar, self.gang_0_input, self.gang_0_spikes,
                      self.gang_1_input, self.gang_1_spikes):
            if hasattr(layer, 'cuda'):
                layer.cuda(*args, **kwargs)
        return self

    def parse_config(self, config, prefix='', key='retina_config_key'):
        """Configure all stages from a VirtualRetina-style configuration.

        Parameters
        ----------
        config : RetinaConfiguration or str
            a configuration object, or the path of an xml file that is
            loaded with `RetinaConfiguration.read_xml`.
        prefix : str
            accepted for interface compatibility; not used by this method.
        key : str
            passed on unchanged to each stage's `parse_config`.

        Stages without a `parse_config` method (e.g. disabled ones) are skipped.
        """
        if isinstance(config, str):
            config_file = config
            config = RetinaConfiguration()
            config.read_xml(config_file)
        stage_prefixes = [
            (self.opl, 'outer-plexiform-layers.0.linear-version.'),
            (self.bipolar, 'contrast-gain-control.'),
            (self.gang_0_input, 'ganglion-layers.0.'),
            (self.gang_0_spikes, 'ganglion-layers.0.spiking-channel.'),
            (self.gang_1_input, 'ganglion-layers.1.'),
            (self.gang_1_spikes, 'ganglion-layers.1.spiking-channel.'),
        ]
        for stage, stage_prefix in stage_prefixes:
            if hasattr(stage, 'parse_config'):
                stage.parse_config(config, prefix=stage_prefix, key=key)

    def forward(self, inp):
        """Run `inp` through every entry of `self.commands` in order.

        Parameters
        ----------
        inp
            the input of the first stage, stored in buffer 'I1'.

        Returns
        -------
        Output
            the final contents of buffers 'I1' and 'I2', with keys
            'ganglion_spikes_ON' and 'ganglion_spikes_OFF'.

        ``self._timing`` is replaced with the per-stage timings of this
        call; if ``keep_timing_info`` is True it is also appended to
        ``self.timing_info``.
        """
        import datetime
        self._timing = []
        io_buffers = {'I1': inp}
        for b_out, f, b_in in self.commands:
            start_time = datetime.datetime.now()
            if isinstance(f, str) and f == 'copy':
                # fan the single input buffer out to every output buffer
                value = io_buffers[b_in[0]]
            else:
                value = f(*[io_buffers[name] for name in b_in])
                if isinstance(value, Output):
                    value = value[0]  # we can only use the first output
            for name in b_out:
                io_buffers[name] = value
            elapsed = (datetime.datetime.now() - start_time).total_seconds()
            self._timing.append((f, elapsed))
        if self.keep_timing_info:
            self.timing_info.append(self._timing)
        return Output([io_buffers['I1'], io_buffers['I2']],
                      keys=['ganglion_spikes_ON', 'ganglion_spikes_OFF'])