getting ready for a new tag
Imraj-Singh committed Sep 25, 2024
1 parent 5c3051c commit 6d58c5f
Showing 5 changed files with 381 additions and 120 deletions.
6 changes: 3 additions & 3 deletions bsrem_bb.py
@@ -130,8 +130,8 @@ def update(self):
         self.compute_rdp_diag_hess = False
         self.eps = self.dataset.OSEM_image.max()/1e3
         x_norm = self.x.norm()
-        print("prior: ", prior_grad.norm(), " lhkd: ", lhkd_grad.norm(), " x: ", x_norm, " g: ", g.norm(), " prior/x: ", prior_grad.norm()/x_norm, " lhkd/x: ", lhkd_grad.norm()/x_norm, " g/x: ", g.norm()/x_norm)
-        print("prior/lhkd: ", prior_grad.norm()/lhkd_grad.norm(), " prior/g: ", prior_grad.norm()/g.norm(), " lhkd/g: ", lhkd_grad.norm()/g.norm())
+        #print("prior: ", prior_grad.norm(), " lhkd: ", lhkd_grad.norm(), " x: ", x_norm, " g: ", g.norm(), " prior/x: ", prior_grad.norm()/x_norm, " lhkd/x: ", lhkd_grad.norm()/x_norm, " g/x: ", g.norm()/x_norm)
+        #print("prior/lhkd: ", prior_grad.norm()/lhkd_grad.norm(), " prior/g: ", prior_grad.norm()/g.norm(), " lhkd/g: ", lhkd_grad.norm()/g.norm())
 
         #g.multiply(self.x + self.eps, out=self.x_update)
         #self.x_update.divide(self.average_sensitivity, out=self.x_update)
@@ -154,7 +154,7 @@ def update(self):
 
         step_size = alpha_long #np.sqrt(alpha_long*alpha_short)
         #print("step size: ", step_size)
-        print("step size: ", step_size)
+        #print("step size: ", step_size)
 
         self.x_prev = self.x.copy()
         self.x_update_prev = self.x_update.copy()
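In this hunk, `step_size = alpha_long` together with the commented-out `np.sqrt(alpha_long*alpha_short)` alternative points at Barzilai-Borwein (BB) step-size rules, and the `self.x_prev` / `self.x_update_prev` copies at the end are exactly the state such rules need. The computation of `alpha_long` and `alpha_short` falls outside the shown context, so as a point of reference only, here is a minimal NumPy sketch of the classic BB1/BB2 formulas (the function name `bb_step_sizes` and its signature are illustrative, not the repository's code):

import numpy as np

def bb_step_sizes(x, x_prev, g, g_prev):
    # s: change in the iterate, y: change in the gradient between iterations.
    s = np.ravel(x - x_prev)
    y = np.ravel(g - g_prev)
    sy = float(s @ y)
    alpha_long = float(s @ s) / sy    # BB1, the "long" step
    alpha_short = sy / float(y @ y)   # BB2, the "short" step
    return alpha_long, alpha_short

# The hunk above keeps step_size = alpha_long; the commented-out
# np.sqrt(alpha_long*alpha_short) would be their geometric mean.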
102 changes: 27 additions & 75 deletions main.py
@@ -1,25 +1,27 @@
 """Main file to modify for submissions.
 Once renamed or symlinked as `main.py`, it will be used by `petric.py` as follows:
 >>> from main import Submission, submission_callbacks
 >>> from petric import data, metrics
 >>> algorithm = Submission(data)
 >>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks)
 """
 from cil.optimisation.algorithms import Algorithm
 from cil.optimisation.utilities import callbacks
 
 
-from bsrem_saga import SAGA
-from utils.number_of_subsets import compute_number_of_subsets
 
 from sirf.contrib.partitioner import partitioner
 #from utils.partioner_function import data_partition
 #from utils.partioner_function_no_obj import data_partition
 
-assert issubclass(SAGA, Algorithm)
+from bsrem_bb import BSREM
 
 
 import torch
 torch.cuda.set_per_process_memory_fraction(0.8)
 
 import setup_model
+assert issubclass(BSREM, Algorithm)
 
 
 class MaxIteration(callbacks.Callback):
     """
     The organisers try to `Submission(data).run(inf)` i.e. for infinite iterations (until timeout).
     This callback forces stopping after `max_iteration` instead.
     """
     def __init__(self, max_iteration: int, verbose: int = 1):
         super().__init__(verbose)
         self.max_iteration = max_iteration
@@ -28,75 +30,25 @@ def __call__(self, algorithm: Algorithm):
         if algorithm.iteration >= self.max_iteration:
             raise StopIteration
 
 
-class Submission(SAGA):
+class Submission(BSREM):
     def __init__(self, data,
-                 update_objective_interval: int = 10,
-                 **kwargs):
-
-        tof = (data.acquired_data.shape[0] > 1)
-        views = data.acquired_data.shape[2]
-        num_subsets = compute_number_of_subsets(views, tof)
-
+                 update_objective_interval: int = 1,
+                 **kwargs):
+
         data_sub, _, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
-                                                           data.mult_factors, num_subsets,
-                                                           initial_image=data.OSEM_image,
-                                                           mode = "staggered")
-
-        self.dataset = data
-
-        # WARNING: modifies prior strength with 1/num_subsets
-        data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(data_sub))
+                                                           data.mult_factors, 1,
+                                                           initial_image=data.OSEM_image)
+        # WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations)
+        data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs))
         data.prior.set_up(data.OSEM_image)
 
-        sensitivity = data.OSEM_image.get_uniform_copy(0)
-        for s in range(len(data_sub)):
-            obj_funs[s].set_up(data.OSEM_image)
-            sensitivity.add(obj_funs[s].get_subset_sensitivity(0), out=sensitivity)
-
-        pll_grad = data.OSEM_image.get_uniform_copy(0)
-        for s in range(len(data_sub)):
-            pll_grad.add(obj_funs[s].gradient(data.OSEM_image), out=pll_grad)
-
-        average_sensitivity = sensitivity.clone() / num_subsets
-        average_sensitivity += average_sensitivity.max()/1e4
-
-        sensitivity += sensitivity.max()/1e4
-        eps = data.OSEM_image.max()/1e3
-
-        prior_grad = data.prior.gradient(data.OSEM_image) * num_subsets
-
-        grad = (data.OSEM_image + eps) * pll_grad / sensitivity
-        prior_grad = (data.OSEM_image + eps) * prior_grad / sensitivity
-
-        DEVICE = "cuda"
-
-        initial_images = torch.from_numpy(data.OSEM_image.as_array()).float().to(DEVICE).unsqueeze(0)
-        prior_grads = torch.from_numpy(prior_grad.as_array()).float().to(DEVICE).unsqueeze(0)
-        pll_grads = torch.from_numpy(grad.as_array()).float().to(DEVICE).unsqueeze(0)
-
-        model_inp = torch.cat([initial_images, pll_grads, prior_grads], dim=0).unsqueeze(0)
-        with torch.no_grad():
-            x_pred = setup_model.network_precond(model_inp)
-            x_pred[x_pred < 0] = 0
-
-        #del setup_model.network_precond
-        del initial_images
-        del prior_grads
-        del pll_grads
-        del model_inp
-
-        initial = data.OSEM_image.clone()
-        initial.fill(x_pred.detach().cpu().numpy().squeeze())
-
         for f in obj_funs: # add prior evenly to every objective function
             f.set_prior(data.prior)
+        self.dataset = data
 
-        super().__init__(data=data_sub,
-                         obj_funs=obj_funs,
-                         initial=initial,
-                         average_sensitivity=average_sensitivity,
+        super().__init__(data_sub,
+                         obj_funs,
+                         initial=data.OSEM_image,
                          update_objective_interval=update_objective_interval)
 
-submission_callbacks = []
+submission_callbacks = [] #[MaxIteration(660)]
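The WARNING comment marks the one subtle piece of bookkeeping in this diff: the prior's penalisation factor is divided by the number of subset objectives before the `f.set_prior(data.prior)` loop attaches the same prior to every subset. A minimal sketch of why that rescaling leaves the total penalised objective unchanged (toy callables stand in for the SIRF objective functions; none of these names are the repository's):

def total_with_global_prior(subset_lhoods, prior, beta, x):
    # Full objective: sum of subset log-likelihoods minus beta * prior.
    return sum(L(x) for L in subset_lhoods) - beta * prior(x)

def total_with_per_subset_prior(subset_lhoods, prior, beta, x):
    # Attach beta/len(subset_lhoods) of the prior to each subset objective,
    # mirroring set_penalisation_factor(beta / len(obj_funs)) plus the
    # f.set_prior(data.prior) loop in the diff.
    n = len(subset_lhoods)
    return sum(L(x) - (beta / n) * prior(x) for L in subset_lhoods)

subset_lhoods = [lambda x: -0.5 * x**2, lambda x: -0.25 * x**2]
prior = lambda x: abs(x)
assert abs(total_with_global_prior(subset_lhoods, prior, 2.0, 3.0)
           - total_with_per_subset_prior(subset_lhoods, prior, 2.0, 3.0)) < 1e-12

With this commit's single-subset partition (the `1` passed to `partitioner.data_partition`, so `len(obj_funs) == 1`), the division is a no-op, but it keeps the prior strength correct if the subset count changes again.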