diff --git a/pylearn2/train_extensions/live_monitoring.py b/pylearn2/train_extensions/live_monitoring.py index ea4a6e15fd..7ee417e597 100644 --- a/pylearn2/train_extensions/live_monitoring.py +++ b/pylearn2/train_extensions/live_monitoring.py @@ -2,26 +2,53 @@ Training extension for allowing querying of monitoring values while an experiment executes. """ -__authors__ = "Dustin Webb" +__authors__ = "Dustin Webb, Adam Stone, Nicu Tofan" __copyright__ = "Copyright 2010-2012, Universite de Montreal" -__credits__ = ["Dustin Webb"] +__credits__ = ["Dustin Webb", "Adam Stone"] __license__ = "3-clause BSD" __maintainer__ = "LISA Lab" __email__ = "pylearn-dev@googlegroups" import copy +import logging +LOG = logging.getLogger(__name__) try: import zmq - zmq_available = True -except: - zmq_available = False + ZMQ_AVAILABLE = True +except Exception: + ZMQ_AVAILABLE = False + +try: + from PySide import QtCore, QtGui + + import matplotlib + import numpy as np + matplotlib.use('Qt4Agg') + matplotlib.rcParams['backend.qt4'] = 'PySide' + + from matplotlib.backends.backend_qt4agg import ( + FigureCanvasQTAgg as FigureCanvas, + NavigationToolbar2QT as NavigationToolbar) + from matplotlib.figure import Figure + + QT_AVAILABLE = True +except Exception: + QT_AVAILABLE = False try: import matplotlib.pyplot as plt - pyplot_available = True -except: - pyplot_available = False + PYPLOT_AVAILABLE = True +except ImportError: + PYPLOT_AVAILABLE = False + +MPLDC_AVAILABLE = False +try: + if PYPLOT_AVAILABLE: + import mpldatacursor as mpldc + MPLDC_AVAILABLE = True +except ImportError: + pass from functools import wraps from pylearn2.monitor import Monitor @@ -86,14 +113,17 @@ def __init__(self, channel_list, start, end, step=1): ) self.channel_list = channel_list - assert(start >= 0) + assert start >= 0 self.start = start self.end = end - assert(step > 0) + assert step > 0 self.step = step + # This is the payload and it may be a Throwable object or + # the actual result (exact type depends on the message type) + self.data = {} class ChannelsRequest(LiveMonitorMsg): """ @@ -120,12 +150,12 @@ def __init__(self, channel_list, start=0, end=-1, step=1): ) self.channel_list = channel_list - assert(start >= 0) + assert start >= 0 self.start = start self.end = end - assert(step > 0) + assert step > 0 self.step = step @wraps(LiveMonitorMsg.get_response) @@ -144,30 +174,35 @@ class LiveMonitoring(TrainExtension): being monitored in real time. PyZMQ must be installed for this extension to work. + A LiveMonitoring that has no connected subscribers will simply + drop all messages. + Parameters ---------- address : string The IP addresses of the interfaces on which the monitor should listen. req_port : int - The port number to be used to service request. + The port number to be used to service request. To disable the + request-response variant pass 0 to this parameter. pub_port : int - The port number to be used to publish updates. + The port number to be used to publish updates. To disable the + publish-subscribe variant pass 0 to this parameter. """ def __init__(self, address='*', req_port=5555, pub_port=5556): - if not zmq_available: + if not ZMQ_AVAILABLE: raise ImportError('zeromq needs to be installed to ' 'use this module.') self.address = 'tcp://%s' % address - assert(req_port != pub_port) + assert req_port != pub_port - assert(req_port > 1024 and req_port < 65536) + assert req_port > 1024 and req_port < 65536 self.req_port = req_port - assert(pub_port > 1024 and pub_port < 65536) + assert pub_port > 1024 and pub_port < 65536 self.pub_port = pub_port address_template = self.address + ':%d' @@ -181,14 +216,93 @@ def __init__(self, address='*', req_port=5555, pub_port=5556): self.pub_sock = None if self.pub_port > 0: self.pub_sock = self.context.socket(zmq.PUB) - self.req_sock.bind(address_template % self.pub_port) + self.pub_sock.bind(address_template % self.pub_port) # Tracks the number of times on_monitor has been called self.counter = 0 - @wraps(TrainExtension.on_monitor) - def on_monitor(self, model, dataset, algorithm): - monitor = Monitor.get_monitor(model) + # number of entries to be published at a time; + # will be initialized the first time a message is about to be + # published in __build_channel_resp__() + self.post_size = 0 + + def __build_channel_resp__(self, monitor, channel_list, + start=0, end=-1, step=1): + """ + Constructs a response or publish message containing channel data. + + The message will either be an Throwable or a dictionary, with + keys being the names of the channels. + + Individual entries for each channel may also be a Throwableinstance + or channel data (see source for actual content). + + Parameters + ---------- + monitor : Monitor + Model's monitor from where we are about to extract the data + + channel_list : list + A list of the channels for which data is needed. + + start : int + The starting epoch for which data should be returned. + + end : int + The epoch after which data should be returned. + + step : int + The number of epochs to be skipped between data points. + """ + result = {} + if not isinstance(channel_list, list) or len(channel_list) == 0: + channel_list = [] + result = TypeError('ChannelResponse requires a list of channels.') + else: + for channel_name in channel_list: + if channel_name in monitor.channels.keys(): + chan = copy.deepcopy( + monitor.channels[channel_name] + ) + if self.post_size == 0: + self.post_size = len(chan.batch_record) + if end == -1: + end = len(chan.batch_record) + # TODO copying and truncating the records individually + # like this is brittle. Is there a more robust + # solution? + chan.batch_record = chan.batch_record[ + start:end:step + ] + chan.epoch_record = chan.epoch_record[ + start:end:step + ] + chan.example_record = chan.example_record[ + start:end:step + ] + chan.time_record = chan.time_record[ + start:end:step + ] + chan.val_record = chan.val_record[ + start:end:step + ] + result[channel_name] = chan + else: + result[channel_name] = KeyError( + 'Invalid channel: %s' % channel_name + ) + return result + + def __reply_to_req__(self, monitor): + """ + Replies to a request for specific channels or to list all channels. + + Parameters + ---------- + monitor : Monitor + Model's monitor from where we are about to extract the data + + """ try: rsqt_msg = self.req_sock.recv_pyobj(flags=zmq.NOBLOCK) @@ -198,63 +312,59 @@ def on_monitor(self, model, dataset, algorithm): if isinstance(rsp_msg, ChannelListResponse): rsp_msg.data = list(monitor.channels.keys()) - if isinstance(rsp_msg, ChannelsResponse): + elif isinstance(rsp_msg, ChannelsResponse): channel_list = rsp_msg.channel_list - if ( - not isinstance(channel_list, list) - or len(channel_list) == 0 - ): - channel_list = [] - result = TypeError( - 'ChannelResponse requires a list of channels.' - ) - - result = {} - for channel_name in channel_list: - if channel_name in monitor.channels.keys(): - chan = copy.deepcopy( - monitor.channels[channel_name] - ) - end = rsp_msg.end - if end == -1: - end = len(chan.batch_record) - # TODO copying and truncating the records individually - # like this is brittle. Is there a more robust - # solution? - chan.batch_record = chan.batch_record[ - rsp_msg.start:end:rsp_msg.step - ] - chan.epoch_record = chan.epoch_record[ - rsp_msg.start:end:rsp_msg.step - ] - chan.example_record = chan.example_record[ - rsp_msg.start:end:rsp_msg.step - ] - chan.time_record = chan.time_record[ - rsp_msg.start:end:rsp_msg.step - ] - chan.val_record = chan.val_record[ - rsp_msg.start:end:rsp_msg.step - ] - result[channel_name] = chan - else: - result[channel_name] = KeyError( - 'Invalid channel: %s' % rsp_msg.channel - ) - rsp_msg.data = result - + rsp_msg.data = self.__build_channel_resp__(monitor, + channel_list, + rsp_msg.start, + rsp_msg.end, + rsp_msg.step) self.req_sock.send_pyobj(rsp_msg) except zmq.Again: pass - self.counter += 1 + def __publish_results__(self, monitor): + """ + Publishes all channels to dedicated ZMQ slot. + Parameters + ---------- + monitor : Monitor + Model's monitor from where we are about to extract the data + """ + if self.pub_sock is None: + return + + try: + channel_list = list(monitor.channels.keys()) + start = self.counter*self.post_size + end = -1 if self.post_size == 0 else start + self.post_size + rsp_msg = ChannelsResponse(channel_list, start, end, step=1) + rsp_msg.data = self.__build_channel_resp__(monitor, + channel_list, + start, end) + self.pub_sock.send_pyobj(rsp_msg) + except Exception, ex: + LOG.warn("Exception while publishing results in LiveMonitoring:" + + ex.message) + + @wraps(TrainExtension.on_monitor) + def on_monitor(self, model, dataset, algorithm): + monitor = Monitor.get_monitor(model) + if self.req_port > 0: + self.__reply_to_req__(monitor) + if self.pub_port: + self.__publish_results__(monitor) + self.counter += 1 class LiveMonitor(object): """ A utility class for requested data from a LiveMonitoring training extension. + On the publish-subscribe variant, please note that, if the subscriber + is slower than the publisher, the messages will pile up on the publisher. + Parameters ---------- address : string @@ -262,32 +372,63 @@ class LiveMonitor(object): req_port : int The port number on which a LiveMonitoring process is listening. + + subscribe : bool + Use publish-subscribe variant (True) or request-reply (False, default). """ - def __init__(self, address='127.0.0.1', req_port=5555): + def __init__(self, address='127.0.0.1', req_port=5555, subscribe=False): """ """ - if not zmq_available: + if not ZMQ_AVAILABLE: raise ImportError('zeromq needs to be installed to ' 'use this module.') self.address = 'tcp://%s' % address - assert(req_port > 0) + assert req_port > 0 self.req_port = req_port + self.subscribe = subscribe self.context = zmq.Context() - self.req_sock = self.context.socket(zmq.REQ) + if subscribe: + self.req_sock = self.context.socket(zmq.SUB) + self.req_sock.setsockopt(zmq.SUBSCRIBE, "") + else: + self.req_sock = self.context.socket(zmq.REQ) + self.req_sock.connect(self.address + ':' + str(self.req_port)) + # A dictionary that has the names of the channels as keys and + # channel data reported by LiveMonitoring as values. self.channels = {} - def list_channels(self): + def list_channels(self, cached=False): """ Returns a list of the channels being monitored. + + Parameters + ---------- + cached : bool + If a cached version exists, return that instead of sending a + new request. """ - self.req_sock.send_pyobj(ChannelListRequest()) - return self.req_sock.recv_pyobj() + if cached: + if len(self.channels) > 0: + return self.channels.keys() + + if self.subscribe: + # we could create a new socket and send a request here + LOG.warn('Subscribe variant of LiveMonitor is only capable ' + 'of returning cached list of channels. ' + 'Use list_channels(cached=True) to avoid this warning.') + if len(self.channels) > 0: + return self.channels.keys() + else: + return [] + else: + self.req_sock.send_pyobj(ChannelListRequest()) + return self.req_sock.recv_pyobj() def update_channels(self, channel_list, start=-1, end=-1, step=1): """ @@ -312,17 +453,25 @@ def update_channels(self, channel_list, start=-1, end=-1, step=1): step : int The number of epochs to be skipped between data points. """ - assert((start == -1 and end == -1) or end > start) - - if start == -1: - start = 0 - if len(self.channels.keys()) > 0: - channel_name = list(self.channels.keys())[0] - start = len(self.channels[channel_name].epoch_record) - - self.req_sock.send_pyobj(ChannelsRequest( - channel_list, start=start, end=end, step=step - )) + assert (start == -1 and end == -1) or end > start + + if self.subscribe: + if start != -1 or end != -1 or step != 1: + LOG.warn('Subscribe variant of LiveMonitor is only capable ' + 'of retreiving last result.' + 'Use update_channels(channel_list) to avoid this warning.') + # then again, we could see if we already have that range cached + # locally and only throw the warning otherwise + else: + if start == -1: + start = 0 + if len(self.channels.keys()) > 0: + channel_name = list(self.channels.keys())[0] + start = len(self.channels[channel_name].epoch_record) + + self.req_sock.send_pyobj(ChannelsRequest( + channel_list, start=start, end=end, step=step + )) rsp_msg = self.req_sock.recv_pyobj() @@ -335,32 +484,57 @@ def update_channels(self, channel_list, start=-1, end=-1, step=1): if isinstance(rsp_chan, Exception): raise rsp_chan + if self.subscribe: + if channel not in channel_list: + continue + if channel not in self.channels.keys(): self.channels[channel] = rsp_chan else: chan = self.channels[channel] + + len_batch_rec = len(rsp_chan.batch_record) + assert len_batch_rec == len(rsp_chan.epoch_record) + assert len_batch_rec == len(rsp_chan.example_record) + assert len_batch_rec == len(rsp_chan.time_record) + assert len_batch_rec == len(rsp_chan.val_record) + chan.batch_record += rsp_chan.batch_record chan.epoch_record += rsp_chan.epoch_record chan.example_record += rsp_chan.example_record chan.time_record += rsp_chan.time_record chan.val_record += rsp_chan.val_record - def follow_channels(self, channel_list): + def follow_channels(self, channel_list, use_qt=False): """ Tracks and plots a specified set of channels in real time. Parameters ---------- - channel_list : list - A list of the channels for which data has been requested. + channel_list : list or dict + A list of the channels for which data will be requested an plotted + or a dictionary where keys will become the names of the plots while + values are lists of channel names. + use_qt : bool + Use a PySide GUI for plotting, if available. """ - if not pyplot_available: + if use_qt: + self.__qt_follow__(channel_list) + + elif not PYPLOT_AVAILABLE: raise ImportError('pyplot needs to be installed for ' 'this functionality.') + else: + self.__ion_follow__(channel_list) + + def __ion_follow__(self, channel_list): + """ + follow_channels() implementation using ion(). + """ plt.clf() plt.ion() while True: - self.update_channel(channel_list) + self.update_channels(channel_list) plt.clf() for channel_name in self.channels: plt.plot( @@ -371,3 +545,170 @@ def follow_channels(self, channel_list): plt.legend() plt.ion() plt.draw() + + def __qt_follow__(self, channel_list): + """ + follow_channels() implementation using Qt. + """ + if not QT_AVAILABLE: + LOG.warning( + 'follow_channels called with use_qt=True, but PySide ' + 'is not available. Falling back on matplotlib ion().') + self.__ion_follow__(channel_list) + else: + # only create new qt app if running the first time in session + + if isinstance(channel_list, dict): + self.channel_dict = channel_list + tmp_list = [] + for k in channel_list: + tmp_list.extend(channel_list[k]) + channel_list = tmp_list + + # remove duplicates in the list of channels + self.channel_list = list(set(tmp_list)) + else: + self.channel_list = channel_list + self.channel_dict = {'': channel_list} + + if len(self.channel_list) == 0: + raise ValueError('No channel name provided; ' + 'channel_list must be either ' + 'a list or a dict') + + if not hasattr(self, 'gui'): + self.gui = LiveMonitorGUI(self, + self.channel_list, + self.channel_dict) + + self.gui.start() + +if QT_AVAILABLE: + + class LiveMonitorGUI(QtGui.QMainWindow): + """ + PySide GUI implementation for live monitoring channels. + + Parameters + ---------- + live_mon : LiveMonitor instance + The LiveMonitor instance to which the GUI belongs. + + channel_list : list + A list of the channels to display. + """ + def __init__(self, live_mon, channel_list, channel_dict): + + self.app = QtGui.QApplication(["Live Monitor"]) + + super(LiveMonitorGUI, self).__init__() + self.live_mon = live_mon + self.channel_list = channel_list + self.channel_dict = channel_dict + self.updater_thread = UpdaterThread(live_mon, channel_list) + self.updater_thread.updated.connect(self.__refresh__) + self.__init_ui__() + + def __common_ui__(self): + if MPLDC_AVAILABLE: + opts = {'hover': True, + 'xytext':(15, -30), + 'formatter':"{label} {y:0.3g}\nat epoch {x:0.0f}".format, + 'keybindings':{'hide':'h', 'toggle':'e'}, + 'bbox':{'fc':'white'}, + 'arrowprops': {'arrowstyle':'simple', + 'fc':'white', + 'alpha':0.1}} + #draggable=True + mpldc.datacursor(axes=self.fig.axes, **opts) + + def __init_ui__(self): + matplotlib.rcParams.update({'font.size': 8}) + self.resize(600, 400) + self.fig = Figure(figsize=(600, 400), dpi=72, + facecolor=(1, 1, 1), edgecolor=(0, 0, 0)) + + arrange = {1: [1, 1], 2: [1, 2], 3: [1, 3], 4: [2, 2], + 5: [2, 3], 6: [2, 3], 7: [2, 4], 8: [2, 4], + 9: [3, 3], 10: [3, 4], 11: [3, 4], 12: [3, 4]} + splot_len = len(self.channel_dict) + if splot_len < 13: + splot_layout = arrange[splot_len] + else: + splot_layout = [splot_len//5, 5] + + self.ax = [] + for splot_i in enumerate(len(self.channel_dict)): + self.ax.append(self.fig.add_subplot(splot_layout[0], + splot_layout[1], + splot_i+1)) + + self.fig.subplots_adjust(left=0.02, right=0.98, + top=0.98, bottom=0.02, + hspace=0.1) + self.__common_ui__() + self.canvas = FigureCanvas(self.fig) + self.setCentralWidget(self.canvas) + ntb = NavigationToolbar(self.canvas, self) + self.addToolBar(ntb) + + def __refresh__(self): + if not self.live_mon.channels: + self.updater_thread.start() + return + + splot_i = 0 + for splot_name in self.channel_dict: + self.ax[splot_i].cla() # clear previous plot + chan_list = self.channel_dict[splot_name] + + for channel_name in chan_list: + if not channel_name in self.live_mon.channels: + splot_i = splot_i + 1 + continue + + X = epoch_record = self.live_mon.channels[channel_name].epoch_record + Y = val_record = self.live_mon.channels[channel_name].val_record + + indices = np.nonzero(np.diff(epoch_record))[0] + 1 + epoch_record_split = np.split(epoch_record, indices) + val_record_split = np.split(val_record, indices) + + X = np.zeros(len(epoch_record)) + Y = np.zeros(len(epoch_record)) + + for i, epoch in enumerate(epoch_record_split): + + j = i*len(epoch_record_split[0]) + X[j: j + len(epoch)] = ( + 1.*np.arange(len(epoch)) / len(epoch) + epoch[0]) + Y[j: j + len(epoch)] = val_record_split[i] + + self.ax[splot_i].plot(X, Y, label=channel_name) + self.ax[splot_i].legend(loc='best', fancybox=True, framealpha=0.5) + self.fig.axes[splot_i].set_xlabel('Epoch') + self.fig.axes[splot_i].set_ylabel('Value') + self.fig.axes[splot_i].set_title(splot_name) + splot_i = splot_i + 1 + #self.fig.axes[splot_i].set_title('Tracking %d channels' % len(chan_list)) + + self.__common_ui__() + self.canvas.draw() + self.updater_thread.start() + + def start(self): + self.show() + self.updater_thread.start() + self.app.exec_() + + class UpdaterThread(QtCore.QThread): + updated = QtCore.Signal() + + def __init__(self, live_mon, channel_list): + super(UpdaterThread, self).__init__() + self.live_mon = live_mon + self.channel_list = channel_list + + def run(self): + self.live_mon.update_channels(self.channel_list) # blocking + self.updated.emit()