Commit

small changes
alexdenker committed Sep 24, 2024
1 parent bdf2b26 commit ff6cc74
Showing 6 changed files with 25 additions and 8 deletions.
2 changes: 1 addition & 1 deletion bsrem_bb.py
@@ -89,7 +89,7 @@ def update(self):
delta_x = self.x - self.x_prev
delta_g = self.x_update_prev - self.x_update

-        dot_product = delta_g.dot(delta_x)
+        dot_product = delta_g.dot(delta_x) # (deltag * deltax).sum()
alpha_long = delta_x.norm()**2 / np.abs(dot_product)
#dot_product = delta_x.dot(delta_g)
#alpha_short = np.abs((dot_product).sum()) / delta_g.norm()**2
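Note on the bsrem_bb.py change: the touched line only adds an explanatory comment; the surrounding code computes the long Barzilai-Borwein step length alpha_long = ||delta_x||^2 / |<delta_g, delta_x>|. A minimal NumPy sketch of the same formula, assuming the iterates and preconditioned updates are available as flat arrays (the function and argument names below are illustrative, not attributes of the class):

import numpy as np

def bb_long_step(x, x_prev, update, update_prev):
    """Long Barzilai-Borwein step length: ||dx||^2 / |<dg, dx>|."""
    delta_x = x - x_prev
    delta_g = update_prev - update            # mirrors self.x_update_prev - self.x_update
    dot_product = np.dot(delta_g, delta_x)    # equivalent to (delta_g * delta_x).sum()
    return np.dot(delta_x, delta_x) / np.abs(dot_product)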
27 changes: 22 additions & 5 deletions bsrem_saga.py
@@ -66,6 +66,9 @@ def __init__(self, data, initial, average_sensitivity,
self.sum_gm = self.x.get_uniform_copy(0)
self.x_update = self.x.get_uniform_copy(0)

+        self.last_objective_function = self.objective_function_inter(self.x)
+        self.gamma = 1.0 # scaling for learning rate

def subset_sensitivity(self, subset_num):
raise NotImplementedError

@@ -82,6 +85,15 @@ def epoch(self):
return self.iteration // self.num_subsets

def update(self):
+        if self.epoch() % 4 == 0 and self.iteration % self.num_subsets == 0 and self.epoch() > 0:
+            loss = self.objective_function_inter(self.x)
+            #print("Objective at ", self.epoch(), " is = ", loss)
+
+            if loss < self.last_objective_function:
+                #print("Reduce learning rate!")
+                self.gamma = self.gamma * 0.75
+
+            self.last_objective_function = loss
# for the first epochs just do SGD
if self.epoch() < 2:
# construct gradient of subset
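The block added at the top of update() implements a simple step-size backoff: once per four epochs (at the first subset iteration of that epoch) the full objective is re-evaluated via objective_function_inter, and the scaling factor gamma is shrunk by 0.75 whenever the objective has dropped since the last check (BSREM maximises its objective, so a drop suggests the effective step size is too large). A standalone sketch of that rule, with illustrative names:

def maybe_decay_gamma(gamma, current_obj, last_obj, factor=0.75):
    """Shrink the step-size scaling when the (maximised) objective decreased."""
    if current_obj < last_obj:
        gamma = gamma * factor
    return gamma, current_obj   # current_obj becomes the new reference value

The returned pair corresponds to self.gamma and self.last_objective_function in the class.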
@@ -143,7 +155,7 @@ def update(self):
if self.update_filter is not None:
self.update_filter.apply(self.x_update)

-        self.x.sapyb(1.0, self.x_update, self.alpha, out=self.x)
+        self.x.sapyb(1.0, self.x_update, self.gamma*self.alpha, out=self.x)
#self.x += self.alpha * self.x_update

# threshold to non-negative
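sapyb is the SIRF/CIL data-container method for a scaled combination: x.sapyb(a, y, b, out=x) computes a*x + b*y and writes the result into x. The commit only changes the second scale factor from alpha to gamma*alpha. A NumPy equivalent of the updated line, including the non-negativity threshold referred to in the comment (array names here are illustrative):

import numpy as np

def apply_update(x, x_update, alpha, gamma):
    """x <- 1.0*x + (gamma*alpha)*x_update, then clamp negatives to zero."""
    x += gamma * alpha * x_update
    np.maximum(x, 0.0, out=x)    # threshold to non-negative
    return x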
@@ -175,11 +187,8 @@ def update(self):

# DOG lr
self.alpha = self.max_distance / np.sqrt(self.sum_gradient)

-        #if self.alpha > self.last_alpha:
-        #    self.sum_gradient += 0.0001 * self.sum_gradient

-        self.x.sapyb(1.0, self.x_update, self.alpha, out=self.x)
+        self.x.sapyb(1.0, self.x_update, self.gamma*self.alpha, out=self.x)
#self.x += self.alpha * self.x_update

# threshold to non-negative
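The step size in this branch follows a DoG-style (distance over gradients) rule: the largest distance travelled from the initial iterate divided by the square root of an accumulated gradient quantity. Assuming max_distance and sum_gradient are maintained elsewhere in the class as the running maximum distance and the running sum of squared gradient norms (not shown in this hunk), a sketch of the step-size computation:

import numpy as np

def dog_step(max_distance, sum_gradient, gamma=1.0):
    """DoG-style step size; gamma is the extra damping introduced in this commit."""
    return gamma * max_distance / np.sqrt(sum_gradient)

In the actual code gamma multiplies the update inside sapyb rather than alpha itself; folding it into the step size here is equivalent.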
@@ -203,6 +212,14 @@ def objective_function(self, x):
# v += self.subset_objective(x, s)
return v

+    def objective_function_inter(self, x):
+        ''' value of objective function summed over all subsets '''
+        v = 0
+        for s in range(len(self.data)):
+            v += self.subset_objective(x, s)
+        return v


def subset_objective(self, x, subset_num):
''' value of objective function for one subset '''
raise NotImplementedError
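objective_function_inter is the helper that backs the learning-rate check in update(): it sums the abstract subset_objective over all subsets, whereas the accumulation in the pre-existing objective_function appears commented out in the context above. A standalone version of the same reduction, with subset_objective supplied as a callable (purely illustrative):

def objective_over_subsets(x, data, subset_objective):
    """Sum a per-subset objective over all subsets, as objective_function_inter does."""
    return sum(subset_objective(x, s) for s in range(len(data)))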
Binary file added checkpoint/model.pt
Binary file added checkpoint/multi_step_model.pt
2 changes: 1 addition & 1 deletion main_EWS_SAGA.py
@@ -81,7 +81,7 @@ def __init__(self, data,
x_pred = setup_model.network_precond(model_inp)
x_pred[x_pred < 0] = 0

-        del setup_model.network_precond
+        #del setup_model.network_precond
del initial_images
del prior_grads
del pll_grads
2 changes: 1 addition & 1 deletion main_Full_Gradient.py
@@ -10,7 +10,7 @@
from cil.optimisation.algorithms import Algorithm
from cil.optimisation.utilities import callbacks

-#from sirf.contrib.partitioner import partitioner
+from sirf.contrib.partitioner import partitioner

from bsrem_bb import BSREM

