From a36629c963d7156c23fb878c5e77c63301260a60 Mon Sep 17 00:00:00 2001 From: Andrew Davison Date: Thu, 28 Mar 2019 11:57:52 +0100 Subject: [PATCH] In sonata module, fully implement filtering in node sets based on node population attributes; support step current injection --- pyNN/common/populations.py | 20 ++++ pyNN/network.py | 37 +++++-- pyNN/serialization/sonata.py | 187 ++++++++++++++++++++++++----------- 3 files changed, 179 insertions(+), 65 deletions(-) diff --git a/pyNN/common/populations.py b/pyNN/common/populations.py index 202d74e25..398bac925 100644 --- a/pyNN/common/populations.py +++ b/pyNN/common/populations.py @@ -1451,3 +1451,23 @@ def describe(self, template='assembly_default.txt', engine='default'): context = {"label": self.label, "populations": [p.describe(template=None) for p in self.populations]} return descriptions.render(engine, template, context) + + def get_annotations(self, annotation_keys, simplify=True): + """ + Get the values of the given annotations for each population in the Assembly. + """ + if isinstance(annotation_keys, basestring): + annotation_keys = (annotation_keys,) + annotations = defaultdict(list) + + for key in annotation_keys: + is_array_annotation = False + for p in self.populations: + annotation = p.annotations[key] + annotations[key].append(annotation) + is_array_annotation = isinstance(annotation, numpy.ndarray) + if is_array_annotation: + annotations[key] = numpy.hstack(annotations[key]) + if simplify: + annotations[key] = simplify_parameter_array(numpy.array(annotations[key])) + return annotations diff --git a/pyNN/network.py b/pyNN/network.py index 570a401d7..5c79e1b5b 100644 --- a/pyNN/network.py +++ b/pyNN/network.py @@ -3,6 +3,8 @@ """ +import sys +import inspect from itertools import chain try: basestring @@ -22,6 +24,23 @@ def __init__(self, *components): self.views = set([]) self.assemblies = set([]) self.projections = set([]) + self.add(*components) + + @property + def sim(self): + """Figure out which PyNN backend module this Network is using.""" + # we assume there is only one. Could be mixed if using multiple simulators + # at once. + populations_module = inspect.getmodule(list(self.populations)[0].__class__) + return sys.modules[".".join(populations_module.__name__.split(".")[:-1])] + + def count_neurons(self): + return sum(population.size for population in chain(self.populations)) + + def count_connections(self): + return sum(projection.size() for projection in chain(self.projections)) + + def add(self, *components): for component in components: if isinstance(component, Population): self.populations.add(component) @@ -37,18 +56,24 @@ def __init__(self, *components): else: raise TypeError() - def count_neurons(self): - return sum(population.size for population in chain(self.populations)) - - def count_connections(self): - return sum(projection.size() for projection in chain(self.projections)) - def get_component(self, label): for obj in chain(self.populations, self.views, self.assemblies, self.projections): if obj.label == label: return obj return None + def filter(self, cell_types=None): + """Return an Assembly of all components that have a cell type in the list""" + if cell_types is None: + raise NotImplementedError() + else: + if cell_types == "all": + return self.sim.Assembly(*(pop for pop in self.populations + if pop.celltype.injectable)) # or could use len(receptor_types) > 0 + else: + return self.sim.Assembly(*(pop for pop in self.populations + if pop.celltype.__class__ in cell_types)) + def record(self, variables, to_file=None, sampling_interval=None, include_spike_source=True): for obj in chain(self.populations, self.assemblies): if include_spike_source or obj.injectable: # spike sources are not injectable diff --git a/pyNN/serialization/sonata.py b/pyNN/serialization/sonata.py index 99eb4ce6d..a3cab5d35 100644 --- a/pyNN/serialization/sonata.py +++ b/pyNN/serialization/sonata.py @@ -105,6 +105,8 @@ def write(self, blocks): Write a list of Blocks to SONATA HDF5 files. """ + if not os.path.isdir(self.base_dir): + os.makedirs(self.base_dir) # Write spikes spike_file_path = join(self.base_dir, self.spike_file) spikes_file = h5py.File(spike_file_path, 'w') @@ -131,36 +133,38 @@ def write(self, blocks): file_path = join(self.base_dir, file_name) signal_file = h5py.File(file_path, 'w') - population_name = self.node_sets[report_metadata["cells"]]["population"] - node_ids = self.node_sets[report_metadata["cells"]]["node_id"] + targets = self.node_sets[report_metadata["cells"]] for block in blocks: - if block.name == population_name: - if len(block.segments) > 1: - raise NotImplementedError() - signal = block.segments[0].filter(name=report_metadata["variable_name"]) - if len(signal) != 1: - raise NotImplementedError() - - report_group = signal_file.create_group("report") - population_group = report_group.create_group(population_name) - dataset = population_group.create_dataset("data", data=signal[0].magnitude) - dataset.attrs["units"] = signal[0].units.dimensionality.string - dataset.attrs["variable_name"] = report_metadata["variable_name"] - n = dataset.shape[1] - mapping_group = population_group.create_group("mapping") - mapping_group.create_dataset("node_ids", data=node_ids) - # "gids" not in the spec, but expected by some bmtk utils - mapping_group.create_dataset("gids", data=node_ids) - #mapping_group.create_dataset("index_pointers", data=np.zeros((n,))) - mapping_group.create_dataset("index_pointer", data=np.arange(0, n+1)) # ??spec unclear - mapping_group.create_dataset("element_ids", data=np.zeros((n,))) - mapping_group.create_dataset("element_pos", data=np.zeros((n,))) - time_ds = mapping_group.create_dataset("time", - data=(float(signal[0].t_start), - float(signal[0].t_stop), - float(signal[0].sampling_period))) - time_ds.attrs["units"] = "ms" - logger.info("Wrote block {} to {}".format(block.name, file_path)) + for (assembly, mask) in targets: + if block.name == assembly.label: + if len(block.segments) > 1: + raise NotImplementedError() + signal = block.segments[0].filter(name=report_metadata["variable_name"]) + if len(signal) != 1: + raise NotImplementedError() + + node_ids = np.arange(assembly.size)[mask] + + report_group = signal_file.create_group("report") + population_group = report_group.create_group(assembly.label) + dataset = population_group.create_dataset("data", data=signal[0].magnitude) + dataset.attrs["units"] = signal[0].units.dimensionality.string + dataset.attrs["variable_name"] = report_metadata["variable_name"] + n = dataset.shape[1] + mapping_group = population_group.create_group("mapping") + mapping_group.create_dataset("node_ids", data=node_ids) + # "gids" not in the spec, but expected by some bmtk utils + mapping_group.create_dataset("gids", data=node_ids) + #mapping_group.create_dataset("index_pointers", data=np.zeros((n,))) + mapping_group.create_dataset("index_pointer", data=np.arange(0, n+1)) # ??spec unclear + mapping_group.create_dataset("element_ids", data=np.zeros((n,))) + mapping_group.create_dataset("element_pos", data=np.zeros((n,))) + time_ds = mapping_group.create_dataset("time", + data=(float(signal[0].t_start.rescale('ms')), + float(signal[0].t_stop.rescale('ms')), + float(signal[0].sampling_period.rescale('ms')))) + time_ds.attrs["units"] = "ms" + logger.info("Wrote block {} to {}".format(block.name, file_path)) signal_file.close() @@ -232,6 +236,7 @@ def condense(value, types_array): from "/edges//edge_type_id" that applies to this group. Needed to construct parameter arrays. """ + # todo: use lazyarray if isinstance(value, np.ndarray): return value elif isinstance(value, dict): @@ -240,7 +245,12 @@ def condense(value, types_array): if np.all(value_array == value_array[0]): return value_array[0] else: - new_value = np.ones_like(types_array) * np.nan + if np.issubdtype(value_array.dtype, np.number): + new_value = np.ones_like(types_array) * np.nan + elif np.issubdtype(value_array.dtype, np.str_): + new_value = np.array(["UNDEFINED"] * types_array.size) + else: + raise TypeError("Cannot handle annotations that are neither numbers or strings") for node_type_id, val in value.items(): new_value[types_array == node_type_id] = val return new_value @@ -584,10 +594,10 @@ def import_from_sonata(config_file, sim): net = Network() for node_population in sonata_node_populations: assembly = node_population.to_assembly(sim) - net.assemblies.add(assembly) + net.add(assembly) for edge_population in sonata_edge_populations: projections = edge_population.to_projections(net, sim) - net.projections.update(projections) + net.add(*projections) return net @@ -777,7 +787,7 @@ def to_population(self, sim): if name in cell_type_cls.default_parameters: parameters[name] = condense(value, self.node_types_array) else: - annotations[name] = value + annotations[name] = condense(value, self.node_types_array) # todo: handle spatial structure - nodes_file["nodes"][np_label][ng_label]['x'], etc. # temporary hack to work around problem with 300 Intfire cell example @@ -1072,28 +1082,21 @@ def setup(self, sim): self.sim = sim sim.setup(timestep=self.run_config["dt"]) - def _get_target(self, config, node_sets, net): + def _get_target(self, config, net): if "node_set" in config: # input config - target = node_sets[config["node_set"]] - elif "cells" in config: # recording config + targets = self.node_set_map[config["node_set"]] + elif "cells" in config: # recording config # inconsistency in SONATA spec? Why not call this "node_set" also? - target = node_sets[config["cells"]] - if "model_type" in target: - raise NotImplementedError() - if "location" in target: - raise NotImplementedError() - if "gids" in target: - raise NotImplementedError() - if "population" in target: - assembly = net.get_component(target["population"]) - if "node_id" in target: - indices = target["node_id"] - assembly = assembly[indices] - return assembly + targets = self.node_set_map[config["cells"]] + return targets - def _set_input_spikes(self, input_config, node_sets, net): + def _set_input_spikes(self, input_config, net): # determine which assembly the spikes are for - assembly = self._get_target(input_config, node_sets, net) + targets = self._get_target(input_config, net) + if len(targets) != 1: + raise NotImplementedError() + base_assembly, mask = targets[0] + assembly = base_assembly[mask] assert isinstance(assembly, self.sim.Assembly) # load spike data from file @@ -1111,22 +1114,88 @@ def _set_input_spikes(self, input_config, node_sets, net): if len(spiketrains) != assembly.size: raise NotImplementedError() # todo: map cell ids in spikes file to ids/index in the population - #logger.info("SETTING SPIKETIMES") - #logger.info(spiketrains) assembly.set(spike_times=[Sequence(st.times.rescale('ms').magnitude) for st in spiketrains]) + def _set_input_currents(self, input_config, net): + # determine which assembly the currents are for + if "input_file" in input_config: + raise NotImplementedError("Current clamp from source file not yet supported.") + targets = self._get_target(input_config, net) + if len(targets) != 1: + raise NotImplementedError() + base_assembly, mask = targets[0] + assembly = base_assembly[mask] + assert isinstance(assembly, self.sim.Assembly) + amplitude = input_config["amp"] # nA + if self.target_simulator == "NEST": + amplitude = input_config["amp"]/1000.0 # pA + + current_source = self.sim.DCSource(amplitude=amplitude, + start=input_config["delay"], + stop=input_config["delay"] + input_config["duration"]) + assembly.inject(current_source) + + def _calculate_node_set_map(self, net): + # for each "node set" in the config, determine which populations + # and node_ids it corresponds to + self.node_set_map = {} + + # first handle implicit node sets - i.e. each node population is an implicit node set + for assembly in net.assemblies: + self.node_set_map[assembly.label] = [(assembly, slice(None))] + + # now handle explictly-declared node sets + # todo: handle compound node sets + for node_set_name, node_set_definition in self.node_sets.items(): + if isinstance(node_set_definition, dict): # basic node set + filters = node_set_definition + if "population" in filters: + assemblies = [net.get_component(filters["population"])] + else: + assemblies = list(net.assemblies) + + self.node_set_map[node_set_name] = [] + for assembly in assemblies: + mask = True + for attr_name, attr_value in filters.items(): + print(attr_name, attr_value, "____") + if attr_name == "population": + continue + elif attr_name == "node_id": + # convert integer mask to boolean mask + node_mask = np.zeros(assembly.size, dtype=bool) + node_mask[attr_value] = True + mask = np.logical_and(mask, node_mask) + else: + values = assembly.get_annotations(attr_name)[attr_name] + mask = np.logical_and(mask, values == attr_value) + if isinstance(mask, (bool, np.bool_)) and mask == True: + mask = slice(None) + self.node_set_map[node_set_name].append((assembly, mask)) + elif isinstance(node_set_definition, list): # compound node set + raise NotImplementedError("Compound node sets not yet supported") + else: + raise TypeError("Expecting node set definition to be a list or dict") + def execute(self, net): + self._calculate_node_set_map(net) + # create/configure inputs for input_name, input_config in self.inputs.items(): - if input_config["input_type"] != "spikes": - raise NotImplementedError() - self._set_input_spikes(input_config, self.node_sets, net) + if input_config["input_type"] == "spikes": + self._set_input_spikes(input_config, net) + elif input_config["input_type"] == "current_clamp": + self._set_input_currents(input_config, net) + else: + raise NotImplementedError("Only 'spikes' and 'current_clamp' supported") # configure recording net.record('spikes', include_spike_source=False) # SONATA requires that we record spikes from all non-virtual nodes for report_name, report_config in self.reports.items(): - assembly = self._get_target(report_config, self.node_sets, net) - assembly.record(report_config["variable_name"]) + targets = self._get_target(report_config, net) + for (base_assembly, mask) in targets: + assembly = base_assembly[mask] + assembly.record(report_config["variable_name"]) # run simulation self.sim.run(self.run_config["tstop"]) @@ -1141,7 +1210,7 @@ def execute(self, net): spikes_file=self.output.get("spikes_file", "spikes.h5"), spikes_sort_order=self.output["spikes_sort_order"], report_config=self.reports, - node_sets=self.node_sets) + node_sets=self.node_set_map) # todo: handle reports net.write_data(io)