Skip to content

Commit

Permalink
Merge pull request kwikteam#71 from rossant/labels
Browse files Browse the repository at this point in the history
Load and save labels in CSV files
  • Loading branch information
rossant authored Oct 27, 2016
2 parents 6a09199 + 8cab1c7 commit 6d5d1e6
Showing 1 changed file with 54 additions and 26 deletions.
80 changes: 54 additions & 26 deletions phycontrib/template/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#------------------------------------------------------------------------------

import csv
import glob
import logging
from operator import itemgetter
import os
Expand Down Expand Up @@ -205,6 +206,36 @@ def _densify(rows, arr, ind, ncols):
return out


def load_metadata(filename, cluster_ids):
dic = {cluster_id: None for cluster_id in cluster_ids}
if not op.exists(filename):
return dic
with open(filename, 'r') as f:
reader = csv.reader(f, delimiter='\t')
# Skip the header.
for row in reader:
break
for row in reader:
cluster, value = row
cluster = int(cluster)
dic[cluster] = value
return dic


def save_metadata(filename, field_name, metadata):
"""Save metadata in a CSV file."""
import sys
if sys.version_info[0] < 3:
file = open(filename, 'wb')
else:
file = open(filename, 'w', newline='')
with file as f:
writer = csv.writer(f, delimiter='\t')
writer.writerow(['cluster_id', field_name])
writer.writerows([(cluster, metadata[cluster])
for cluster in sorted(metadata)])


class TemplateController(Controller):
gui_name = 'TemplateGUI'

Expand Down Expand Up @@ -420,20 +451,23 @@ def _init_data(self):

# Read the cluster groups.
logger.debug("Loading the cluster groups.")
self.cluster_groups = {}
if op.exists(filenames['cluster_groups']):
with open(filenames['cluster_groups'], 'r') as f:
reader = csv.reader(f, delimiter='\t')
# Skip the header.
for row in reader:
break
for row in reader:
cluster, group = row
cluster = int(cluster)
self.cluster_groups[cluster] = group
for cluster_id in self.cluster_ids:
if cluster_id not in self.cluster_groups:
self.cluster_groups[cluster_id] = None
self.cluster_groups = load_metadata(filenames['cluster_groups'],
self.cluster_ids)

def _set_manual_clustering(self):
super(TemplateController, self)._set_manual_clustering()
mc = self.manual_clustering
# Load labels.
files = glob.glob('*.csv')
for filename in files:
if filename == 'cluster_groups.csv':
continue
field_name = op.basename(filename)
values = load_metadata(filename, self.cluster_ids)
for cluster_id in self.cluster_ids:
mc.cluster_meta.set(field_name, [cluster_id],
values.get(cluster_id, None),
add_to_stack=False)

def get_cluster_templates(self, cluster_id):
spike_ids = self.spikes_per_cluster(cluster_id)
Expand Down Expand Up @@ -716,20 +750,14 @@ def toggle_waveforms():

# Save.
@gui.connect_
def on_request_save(spike_clusters, groups):
def on_request_save(spike_clusters, groups, *labels):
# Save the clusters.
np.save(filenames['spike_clusters'], spike_clusters)
# Save the cluster groups.
import sys
if sys.version_info[0] < 3:
file = open(filenames['cluster_groups'], 'wb')
else:
file = open(filenames['cluster_groups'], 'w', newline='')
with file as f:
writer = csv.writer(f, delimiter='\t')
writer.writerow(['cluster_id', 'group'])
writer.writerows([(cluster, groups[cluster])
for cluster in sorted(groups)])
# Save cluster groups.
save_metadata(filenames['cluster_groups'], 'group', groups)
# Save other labels.
for field_name, dic in labels:
save_metadata('cluster_%s.csv' % field_name, field_name, dic)

# Save the memcache when closing the GUI.
@gui.connect_
Expand Down

0 comments on commit 6d5d1e6

Please sign in to comment.