Skip to content

Commit

Permalink
Merge pull request #33 from qwertpi/master
Browse files Browse the repository at this point in the history
PR to add optional colour map parameter to display_activations
  • Loading branch information
Philippe Rémy authored Apr 16, 2019
2 parents 1cd90f1 + b9e323f commit 358352a
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions keract/keract.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,11 @@ def get_activations(model, x, layer_name=None):
return result


def display_activations(activations, save=False):
def display_activations(activations, cmap=None, save=False):
"""
Plot heatmaps of activations for all filters for each layer
:param activations: dict mapping layers to corresponding activations (1, output_h, output_w, num_filters)
:param cmap: string - a valid matplotlib colourmap to be used
:param save: bool- if the plot should be saved
:return: None
"""
Expand All @@ -107,7 +108,7 @@ def display_activations(activations, save=False):
for i in range(nrows * ncols):
if i < acts.shape[-1]:
img = acts[0, :, :, i]
hmap = axes.flat[i].imshow(img)
hmap = axes.flat[i].imshow(img,cmap=cmap)
axes.flat[i].axis('off')
fig.subplots_adjust(right=0.8)
cbar = fig.add_axes([0.85, 0.15, 0.03, 0.7])
Expand All @@ -116,6 +117,8 @@ def display_activations(activations, save=False):
plt.savefig(layer_name.split('/')[0] + '.png', bbox_inches='tight')
else:
plt.show()
#pyplot figures require manual closing
plt.close(fig)


def display_heatmaps(activations, image, save=False):
Expand Down Expand Up @@ -165,6 +168,7 @@ def display_heatmaps(activations, image, save=False):
plt.savefig(layer_name.split('/')[0] + '.png', bbox_inches='tight')
else:
plt.show()
plt.close(fig)


def display_gradients_of_trainable_weights(gradients, save=False):
Expand Down Expand Up @@ -197,3 +201,5 @@ def display_gradients_of_trainable_weights(gradients, save=False):
plt.savefig(layer_name.split('/')[0] + '.png', bbox_inches='tight')
else:
plt.show()
plt.close(fig)

0 comments on commit 358352a

Please sign in to comment.