Skip to content

Commit c800f93

Browse files
author
John Halloran
committed
feat: add live plotting of updates
1 parent 1b49701 commit c800f93

File tree

3 files changed

+81
-0
lines changed

3 files changed

+81
-0
lines changed

src/diffpy/snmf/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
init_weights=init_weights_file,
1313
init_components=init_components_file,
1414
init_stretch=init_stretch_file,
15+
show_plots=True,
1516
)
1617

1718
print("Done")

src/diffpy/snmf/plotter.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# helper_plot.py
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
5+
6+
class SNMFPlotter:
7+
def __init__(self, figsize=(12, 4)):
8+
plt.ion()
9+
self.fig, self.axes = plt.subplots(1, 3, figsize=figsize)
10+
titles = ["Components", "Weights (rows as series)", "Stretch (rows as series)"]
11+
for ax, t in zip(self.axes, titles):
12+
ax.set_title(t)
13+
self.lines = {"components": [], "weights": [], "stretch": []}
14+
self._layout_done = False
15+
plt.show()
16+
17+
def _ensure_lines(self, ax, key, n_series):
18+
cur = self.lines[key]
19+
if len(cur) != n_series:
20+
ax.cla()
21+
ax.set_title(ax.get_title())
22+
self.lines[key] = [ax.plot([], [])[0] for _ in range(n_series)]
23+
return self.lines[key]
24+
25+
def _update_series(self, ax, key, data_2d):
26+
# Expect rows = separate series for components
27+
data_2d = np.atleast_2d(data_2d)
28+
n_series, n_pts = data_2d.shape
29+
lines = self._ensure_lines(ax, key, n_series)
30+
x = np.arange(n_pts)
31+
for ln, y in zip(lines, data_2d):
32+
ln.set_data(x, y)
33+
ax.relim()
34+
ax.autoscale_view()
35+
36+
def update(self, components, weights, stretch, update_tag=None):
37+
# Components: transpose before plotting
38+
C = np.asarray(components).T
39+
self._update_series(self.axes[0], "components", C)
40+
41+
W = np.asarray(weights)
42+
self._update_series(self.axes[1], "weights", W)
43+
44+
S = np.asarray(stretch)
45+
self._update_series(self.axes[2], "stretch", S)
46+
47+
if update_tag is not None:
48+
self.fig.suptitle(f"Updated: {update_tag}", fontsize=14)
49+
50+
if not self._layout_done:
51+
self.fig.tight_layout()
52+
self._layout_done = True
53+
54+
self.fig.canvas.draw()
55+
self.fig.canvas.flush_events()
56+
plt.pause(0.001)

src/diffpy/snmf/snmf_class.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import cvxpy as cp
22
import numpy as np
3+
from plotter import SNMFPlotter
34
from scipy.optimize import minimize
45
from scipy.sparse import coo_matrix, diags
56

@@ -73,6 +74,7 @@ def __init__(
7374
tol=5e-7,
7475
n_components=None,
7576
random_state=None,
77+
show_plots=False,
7678
):
7779
"""Initialize an instance of SNMF and run the optimization.
7880
@@ -112,6 +114,8 @@ def __init__(
112114
random_state : int Optional Default = None
113115
The seed for the initial guesses at the matrices (A, X, and Y) created by
114116
the decomposition.
117+
show_plots : boolean Optional Default = False
118+
Enables plotting at each step of the decomposition.
115119
"""
116120

117121
self.source_matrix = source_matrix
@@ -123,6 +127,7 @@ def __init__(
123127
self.signal_length, self.n_signals = source_matrix.shape
124128
self.num_updates = 0
125129
self._rng = np.random.default_rng(random_state)
130+
self.plotter = SNMFPlotter() if show_plots else None
126131

127132
# Enforce exclusive specification of n_components or init_weights
128133
if (n_components is None and init_weights is None) or (
@@ -236,6 +241,13 @@ def normalize_results(self):
236241
print(f"Objective function after normalize_components: {self.objective_function:.5e}")
237242
self._objective_history.append(self.objective_function)
238243
self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
244+
if self.plotter is not None:
245+
self.plotter.update(
246+
components=self.components,
247+
weights=self.weights,
248+
stretch=self.stretch,
249+
update_tag="normalize components",
250+
)
239251
if self.objective_difference < self.objective_function * self.tol and outiter >= 7:
240252
break
241253

@@ -252,6 +264,10 @@ def outer_loop(self):
252264
if self.objective_function < self.best_objective:
253265
self.best_objective = self.objective_function
254266
self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()]
267+
if self.plotter is not None:
268+
self.plotter.update(
269+
components=self.components, weights=self.weights, stretch=self.stretch, update_tag="components"
270+
)
255271

256272
self.update_weights()
257273
self.residuals = self.get_residual_matrix()
@@ -262,6 +278,10 @@ def outer_loop(self):
262278
if self.objective_function < self.best_objective:
263279
self.best_objective = self.objective_function
264280
self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()]
281+
if self.plotter is not None:
282+
self.plotter.update(
283+
components=self.components, weights=self.weights, stretch=self.stretch, update_tag="weights"
284+
)
265285

266286
self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
267287
if self._objective_history[-3] - self.objective_function < self.objective_difference * 1e-3:
@@ -276,6 +296,10 @@ def outer_loop(self):
276296
if self.objective_function < self.best_objective:
277297
self.best_objective = self.objective_function
278298
self.best_matrices = [self.components.copy(), self.weights.copy(), self.stretch.copy()]
299+
if self.plotter is not None:
300+
self.plotter.update(
301+
components=self.components, weights=self.weights, stretch=self.stretch, update_tag="stretch"
302+
)
279303

280304
def get_residual_matrix(self, components=None, weights=None, stretch=None):
281305
# Initialize residual matrix as negative of source_matrix

0 commit comments

Comments
 (0)