From 00e94681e6e1bf1895a5bfe9f70e69fcf7ead788 Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Mon, 31 Mar 2025 15:37:13 -0400 Subject: [PATCH 1/9] WIP: rotate --- XPointMLTest.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 99e4e07..ac3e18f 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -57,6 +57,19 @@ def expand_xpoints_mask(binary_mask, kernel_size=9): return expanded_mask +def rotate(frameData,deg): + if deg not in [90, 180, 270]: + print(f"invalid rotation specified... exiting") + sys.exit() + psi = frameData["psi"] + shape = psi.shape + if shape[1] != shape[2]: + print(f"input data must be square... exiting") + sys.exit() + frameData["psi"] = np.rot90(frameData["psi"][0], deg/90).reshape(shape) + frameData["mask"] = np.rot90(frameData["mask"][0], deg/90).reshape(shape) + return frameData + # DATASET DEFINITION class XPointDataset(Dataset): @@ -107,7 +120,13 @@ def __init__(self, paramFile, fnumList, constructJz=1, interpFac=1, # load all the data self.data = [] for fnum in fnumList: - self.data.append(self.load(fnum)) + frameData = self.load(fnum) + self.data.append(frameData) + self.data.append(rotate(frameData,90)) +# self.data.append(rotate(frameData,180)) +# self.data.append(rotate(frameData,270)) +# self.data.append(reflect(frameData,0)) +# self.data.append(reflect(frameData,1)) def __len__(self): return len(self.fnumList) From ed4999e792c2313086f4b5dbc9c7eaa7f5a6cb6d Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Tue, 1 Apr 2025 11:40:51 -0400 Subject: [PATCH 2/9] rotated frames, cmd line arg to enable plotting, tweak file names --- XPointMLTest.py | 94 +++++++++++++++++++++++++++---------------------- 1 file changed, 51 insertions(+), 43 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index ac3e18f..70be108 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -14,6 +14,8 @@ import torch.optim as optim import torch.nn.functional as F +from torchvision.transforms import v2 # rotate tensor + from torch.utils.data import DataLoader, Dataset from timeit import default_timer as timer @@ -61,15 +63,18 @@ def rotate(frameData,deg): if deg not in [90, 180, 270]: print(f"invalid rotation specified... exiting") sys.exit() - psi = frameData["psi"] - shape = psi.shape - if shape[1] != shape[2]: - print(f"input data must be square... exiting") - sys.exit() - frameData["psi"] = np.rot90(frameData["psi"][0], deg/90).reshape(shape) - frameData["mask"] = np.rot90(frameData["mask"][0], deg/90).reshape(shape) - return frameData - + psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR) + mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR) + return { + "fnum": frameData["fnum"], + "rotation": deg, + "psi": psi, + "mask": mask, + "x": frameData["x"], + "y": frameData["y"], + "filenameBase": frameData["filenameBase"], + "params": frameData["params"] + } # DATASET DEFINITION class XPointDataset(Dataset): @@ -123,16 +128,13 @@ def __init__(self, paramFile, fnumList, constructJz=1, interpFac=1, frameData = self.load(fnum) self.data.append(frameData) self.data.append(rotate(frameData,90)) -# self.data.append(rotate(frameData,180)) -# self.data.append(rotate(frameData,270)) -# self.data.append(reflect(frameData,0)) -# self.data.append(reflect(frameData,1)) + self.data.append(rotate(frameData,180)) + self.data.append(rotate(frameData,270)) def __len__(self): - return len(self.fnumList) + return len(self.data) def __getitem__(self, idx): - fnum = self.fnumList[idx] return self.data[idx] def load(self, fnum): @@ -240,6 +242,7 @@ def load(self, fnum): return { "fnum": fnum, + "rotation": 0, "psi": psi_torch, # shape [1, Nx, Ny] "mask": mask_torch, # shape [1, Nx, Ny] // Used in: psi, mask = batch["psi"].to(device), batch["mask"].to(device) "x": x, @@ -427,7 +430,7 @@ def forward(self, inputs, targets): return 1.0 - dice # PLOTTING FUNCTION -def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, filenameBase, interpFac, +def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, filenameBase, interpFac, xpoint_mask=None, titleExtra="", outDir="plots", @@ -455,7 +458,7 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, filenameBase, inte if params["axisEqual"]: plt.gca().set_aspect("equal", "box") - plt.title(f"Vector Potential Contours {titleExtra}, fileNum={fnum}") + plt.title(f"Vector Potential Contours {titleExtra}, fileNum={fnum}, rotation={rotation}") # Overlay X-points if xpoint_mask is given if xpoint_mask is not None: @@ -473,7 +476,7 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, filenameBase, inte basename = os.path.basename(filenameBase) saveFilename = os.path.join( outDir, - f"{basename}_interpFac_{interpFac}_{fnum:04d}{titleExtra.replace(' ','_')}.png" + f"{basename}_interpFac_{interpFac}_frame{fnum:04d}_rotation{rotation}_{titleExtra.replace(' ','_')}.png" ) plt.savefig(saveFilename, dpi=300) print(" Figure written to", saveFilename) @@ -621,6 +624,8 @@ def parseCommandLineArgs(): specify the path to a directory that will be used to cache the outputs of the analytic Xpoint finder ''') + parser.add_argument('--plot', type=bool, default=False, + help='create figures of the ground truth X-points and model identified X-points') args = parser.parse_args() return args @@ -694,6 +699,7 @@ def main(): val_loss.append(validate_one_epoch(model, val_loader, criterion, device)) print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}") + plot_training_history(train_loss, val_loss) print("time (s) to train model: " + str(timer()-t2)) requiredLossDecreaseMagnitude = args.minTrainingLoss @@ -719,6 +725,7 @@ def main(): for item in set: # item is a dict with keys: fnum, psi, mask, psi_np, mask_np, x, y, tmp, params fnum = item["fnum"] + rotation = item["rotation"] psi_np = np.array(item["psi"])[0] mask_gt = np.array(item["mask"])[0] x = item["x"] @@ -738,7 +745,7 @@ def main(): pred_mask_bin = (pred_prob_np > 0.5).astype(np.float32) # Thresholding at 0.5, can be fine tune - print(f"Frame {fnum}:") + print(f"Frame {fnum} rotation {rotation}:") print(f"psi shape: {psi_np.shape}, min: {psi_np.min()}, max: {psi_np.max()}") print(f"pred_bin shape: {pred_bin.shape}, min: {pred_bin.min()}, max: {pred_bin.max()}") print(f" Logits - min: {pred_mask_np.min():.5f}, max: {pred_mask_np.max():.5f}, mean: {pred_mask_np.mean():.5f}") @@ -746,30 +753,31 @@ def main(): print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels") print(f" Binary Mask (X_points) - shape: {pred_mask_bin.shape}, min: {pred_mask_bin.min()}, max: {pred_mask_bin.max()}") - # Plot GROUND TRUTH - plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, filenameBase, interpFac, - xpoint_mask=mask_gt, - titleExtra="(GT X-points)", - outDir=outDir, - saveFig=True - ) - - # Plot CNN PREDICTIONS - plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, filenameBase, interpFac, - xpoint_mask=np.squeeze(pred_mask_bin), - titleExtra="(CNN X-points)", - outDir=outDir, - saveFig=True - ) - - pred_prob_np_full = pred_prob.cpu().numpy() - plot_model_performance( - psi_np, pred_prob_np_full, mask_gt, x, y, params, fnum, filenameBase, - outDir=outDir, - saveFig=True - ) + if args.plot : + # Plot GROUND TRUTH + plot_psi_contours_and_xpoints( + psi_np, x, y, params, fnum, rotation, filenameBase, interpFac, + xpoint_mask=mask_gt, + titleExtra="GTXpoints", + outDir=outDir, + saveFig=True + ) + + # Plot CNN PREDICTIONS + plot_psi_contours_and_xpoints( + psi_np, x, y, params, fnum, rotation, filenameBase, interpFac, + xpoint_mask=np.squeeze(pred_mask_bin), + titleExtra="CNNXpoints", + outDir=outDir, + saveFig=True + ) + + pred_prob_np_full = pred_prob.cpu().numpy() + plot_model_performance( + psi_np, pred_prob_np_full, mask_gt, x, y, params, fnum, filenameBase, + outDir=outDir, + saveFig=True + ) t5 = timer() print("time (s) to apply model: " + str(t5-t4)) From 6699a907da4dde4f4e44798cd3a9815b93d6389f Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Tue, 1 Apr 2025 13:46:07 -0400 Subject: [PATCH 3/9] write to plots dir --- XPointMLTest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 70be108..03268d6 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -568,7 +568,7 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi plt.close() -def plot_training_history(train_losses, val_losses, save_path='output_images/training_history.png'): +def plot_training_history(train_losses, val_losses, save_path='plots/training_history.png'): """ Plots training and validation losses across epochs. From 4b6f9333c6d633a4a4c494a5b7ac4773399e7f8d Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Tue, 1 Apr 2025 14:07:10 -0400 Subject: [PATCH 4/9] write all figures to same dir --- XPointMLTest.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 03268d6..3242a6b 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -118,9 +118,6 @@ def __init__(self, paramFile, fnumList, constructJz=1, interpFac=1, self.params["symBar"] = 1 self.params["colormap"] = 'bwr' - # output directory: - self.outDir = "plots" - os.makedirs(self.outDir, exist_ok=True) # load all the data self.data = [] @@ -626,6 +623,8 @@ def parseCommandLineArgs(): ''') parser.add_argument('--plot', type=bool, default=False, help='create figures of the ground truth X-points and model identified X-points') + parser.add_argument('--plotDir', type=Path, default="./plots", + help='directory where figures are written') args = parser.parse_args() return args @@ -666,6 +665,10 @@ def main(): args = parseCommandLineArgs() checkCommandLineArgs(args) + # output directory: + outDir = args.plotDir + os.makedirs(outDir, exist_ok=True) + t0 = timer() train_fnums = range(args.trainFrameFirst, args.trainFrameLast) val_fnums = range(args.validationFrameFirst, args.validationFrameLast) @@ -710,8 +713,7 @@ def main(): # (D) Plotting after training model.eval() - outDir = "output_images" - os.makedirs(outDir, exist_ok=True) + outDir = "plots" interpFac = 1 # Evaluate on combined set for demonstration. Exam this part to see if save to remove From c6ac01a7e4608a5542ce74bc8abc5778befcc5e4 Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Tue, 1 Apr 2025 15:19:04 -0400 Subject: [PATCH 5/9] reflect on x and y axis --- XPointMLTest.py | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 3242a6b..ead0e5e 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -68,6 +68,25 @@ def rotate(frameData,deg): return { "fnum": frameData["fnum"], "rotation": deg, + "reflectionAxis": -1, # no reflection + "psi": psi, + "mask": mask, + "x": frameData["x"], + "y": frameData["y"], + "filenameBase": frameData["filenameBase"], + "params": frameData["params"] + } + +def reflect(frameData,axis): + if axis not in [0,1]: + print(f"invalid reflection axis specified... exiting") + sys.exit() + psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0) + mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0) + return { + "fnum": frameData["fnum"], + "rotation": 0, + "reflectionAxis": axis, "psi": psi, "mask": mask, "x": frameData["x"], @@ -127,6 +146,8 @@ def __init__(self, paramFile, fnumList, constructJz=1, interpFac=1, self.data.append(rotate(frameData,90)) self.data.append(rotate(frameData,180)) self.data.append(rotate(frameData,270)) + self.data.append(reflect(frameData,0)) + self.data.append(reflect(frameData,1)) def __len__(self): return len(self.data) @@ -240,6 +261,7 @@ def load(self, fnum): return { "fnum": fnum, "rotation": 0, + "reflectionAxis": -1, # no reflection "psi": psi_torch, # shape [1, Nx, Ny] "mask": mask_torch, # shape [1, Nx, Ny] // Used in: psi, mask = batch["psi"].to(device), batch["mask"].to(device) "x": x, @@ -427,7 +449,8 @@ def forward(self, inputs, targets): return 1.0 - dice # PLOTTING FUNCTION -def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, filenameBase, interpFac, +def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, + reflectionAxis, filenameBase, interpFac, xpoint_mask=None, titleExtra="", outDir="plots", @@ -455,7 +478,8 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, filename if params["axisEqual"]: plt.gca().set_aspect("equal", "box") - plt.title(f"Vector Potential Contours {titleExtra}, fileNum={fnum}, rotation={rotation}") + plt.title(f"Vector Potential Contours {titleExtra}, fileNum={fnum}, " + f"reflectionAxis={reflectionAxis}") # Overlay X-points if xpoint_mask is given if xpoint_mask is not None: @@ -473,7 +497,7 @@ def plot_psi_contours_and_xpoints(psi_np, x, y, params, fnum, rotation, filename basename = os.path.basename(filenameBase) saveFilename = os.path.join( outDir, - f"{basename}_interpFac_{interpFac}_frame{fnum:04d}_rotation{rotation}_{titleExtra.replace(' ','_')}.png" + f"{basename}_interpFac_{interpFac}_frame{fnum:04d}_rotation{rotation}_reflection{reflectionAxis}_{titleExtra.replace(' ','_')}.png" ) plt.savefig(saveFilename, dpi=300) print(" Figure written to", saveFilename) @@ -728,6 +752,7 @@ def main(): # item is a dict with keys: fnum, psi, mask, psi_np, mask_np, x, y, tmp, params fnum = item["fnum"] rotation = item["rotation"] + reflectionAxis = item["reflectionAxis"] psi_np = np.array(item["psi"])[0] mask_gt = np.array(item["mask"])[0] x = item["x"] @@ -747,7 +772,7 @@ def main(): pred_mask_bin = (pred_prob_np > 0.5).astype(np.float32) # Thresholding at 0.5, can be fine tune - print(f"Frame {fnum} rotation {rotation}:") + print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:") print(f"psi shape: {psi_np.shape}, min: {psi_np.min()}, max: {psi_np.max()}") print(f"pred_bin shape: {pred_bin.shape}, min: {pred_bin.min()}, max: {pred_bin.max()}") print(f" Logits - min: {pred_mask_np.min():.5f}, max: {pred_mask_np.max():.5f}, mean: {pred_mask_np.mean():.5f}") @@ -758,7 +783,7 @@ def main(): if args.plot : # Plot GROUND TRUTH plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, rotation, filenameBase, interpFac, + psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, xpoint_mask=mask_gt, titleExtra="GTXpoints", outDir=outDir, @@ -767,7 +792,7 @@ def main(): # Plot CNN PREDICTIONS plot_psi_contours_and_xpoints( - psi_np, x, y, params, fnum, rotation, filenameBase, interpFac, + psi_np, x, y, params, fnum, rotation, reflectionAxis, filenameBase, interpFac, xpoint_mask=np.squeeze(pred_mask_bin), titleExtra="CNNXpoints", outDir=outDir, From 6a0ff72cbe29a9edf5aebd499d16dd89f2b554d6 Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Tue, 1 Apr 2025 15:26:28 -0400 Subject: [PATCH 6/9] don't need to add more validation frames --- XPointMLTest.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index ead0e5e..4544e21 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -106,7 +106,7 @@ class XPointDataset(Dataset): - Returns (psiTensor, maskTensor) as a PyTorch (float) pair. """ def __init__(self, paramFile, fnumList, constructJz=1, interpFac=1, - saveFig=1, xptCacheDir=None): + saveFig=1, xptCacheDir=None, rotateAndReflect=False): """ paramFile: Path to parameter file (string). fnumList: List of frames to iterate. @@ -143,11 +143,12 @@ def __init__(self, paramFile, fnumList, constructJz=1, interpFac=1, for fnum in fnumList: frameData = self.load(fnum) self.data.append(frameData) - self.data.append(rotate(frameData,90)) - self.data.append(rotate(frameData,180)) - self.data.append(rotate(frameData,270)) - self.data.append(reflect(frameData,0)) - self.data.append(reflect(frameData,1)) + if rotateAndReflect: + self.data.append(rotate(frameData,90)) + self.data.append(rotate(frameData,180)) + self.data.append(rotate(frameData,270)) + self.data.append(reflect(frameData,0)) + self.data.append(reflect(frameData,1)) def __len__(self): return len(self.data) @@ -698,7 +699,8 @@ def main(): val_fnums = range(args.validationFrameFirst, args.validationFrameLast) train_dataset = XPointDataset(args.paramFile, train_fnums, constructJz=1, - interpFac=1, saveFig=1, xptCacheDir=args.xptCacheDir) + interpFac=1, saveFig=1, xptCacheDir=args.xptCacheDir, + rotateAndReflect=True) val_dataset = XPointDataset(args.paramFile, val_fnums, constructJz=1, interpFac=1, saveFig=1, xptCacheDir=args.xptCacheDir) From 843624b36a5038a90f08904103440e664a7b5245 Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Tue, 1 Apr 2025 16:26:35 -0400 Subject: [PATCH 7/9] attempt to use exponential learning rate --- XPointMLTest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/XPointMLTest.py b/XPointMLTest.py index 4544e21..b125286 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -715,6 +715,7 @@ def main(): criterion = DiceLoss(smooth=1.0) optimizer = optim.Adam(model.parameters(), lr=1e-5) + torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96, verbose=True) t2 = timer() print("time (s) to prepare model: " + str(t2-t1)) From 07b3d93f6e64a08546fb824a3322fcab9952ac8b Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Wed, 2 Apr 2025 07:03:31 -0400 Subject: [PATCH 8/9] training loss reduced by a factor of three by 550 epochs --- XPointMLTest.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index b125286..9ce4748 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -715,7 +715,7 @@ def main(): criterion = DiceLoss(smooth=1.0) optimizer = optim.Adam(model.parameters(), lr=1e-5) - torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96, verbose=True) + scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96) t2 = timer() print("time (s) to prepare model: " + str(t2-t1)) @@ -727,7 +727,9 @@ def main(): for epoch in range(num_epochs): train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device)) val_loss.append(validate_one_epoch(model, val_loader, criterion, device)) - print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} ValLoss={val_loss[-1]}") + lr = scheduler.get_last_lr() + print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} " + f"ValLoss={val_loss[-1]} LearningRate {lr}") plot_training_history(train_loss, val_loss) print("time (s) to train model: " + str(timer()-t2)) From bf16650a5a77a0c95eb374a3dba999bc6e025fc9 Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Wed, 2 Apr 2025 07:15:30 -0400 Subject: [PATCH 9/9] need to call step to update the learning rate --- XPointMLTest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/XPointMLTest.py b/XPointMLTest.py index 9ce4748..614f653 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -728,6 +728,7 @@ def main(): train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device)) val_loss.append(validate_one_epoch(model, val_loader, criterion, device)) lr = scheduler.get_last_lr() + scheduler.step() # update lr with gamma print(f"[Epoch {epoch+1}/{num_epochs}] TrainLoss={train_loss[-1]} " f"ValLoss={val_loss[-1]} LearningRate {lr}")