-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: Off-resonance array backend checks. #212
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
I think just using |
I am setting up the mri-nufft operator such that it is operable with torch's autograd. Here is what I did:
I had to change part of the autodiff.py code into: With the set_context() method added and ctx object taken out of forward method. Kept the code as it was will give While this makes the off-resonance compensated nufft operator run with torch.autograd on the orc_nufft without error. But the print out statement shown backward is never reached. The image x stayed at the initial guess. I used a quite standard torch optimization structure which is shown below: import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
def nufft_forward(x):
y = orc_nufft.op(x)
return y
def nufft_adjoint(y):
x = orc_nufft.adj_op(y)
return x
def reconstruct_image(kspace_data, num_iterations=100, learning_rate=1e-3, regularization_weight=0.01):
x_init = torch.zeros(64,64, requires_grad=True)
kspace_data.requires_grad_(True)
x = x_init
# Define the optimizer
optimizer = optim.Adam([x], lr=learning_rate)
# Reconstruction loop
for iter in range(num_iterations):
optimizer.zero_grad()
# Forward pass: compute the predicted k-space data
print(x.requires_grad)
y_pred = nufft_forward(x)
y_pred.requires_grad_(True)
x.requires_grad_(True)
print(x.requires_grad)
# Data fidelity term: L2 norm between predicted and measured k-space data
data_fidelity = torch.norm(y_pred - kspace_data) ** 2
# Regularization term: Total Variation (TV) regularization
tv_reg = F.l1_loss(x[1:, :], x[:-1, :]) + F.l1_loss(x[:, 1:], x[:, :-1])
# Total loss
loss = data_fidelity + regularization_weight * tv_reg
print(f'"loss grad: {loss.requires_grad}')
# Backward pass
loss.backward()
# Update the image
print(x.requires_grad)
optimizer.step()
# Optionally, print the loss
if (iter + 1) % 10 == 0:
print(f"Iteration {iter + 1}/{num_iterations}, Loss: {loss.item():.4f}")
reconstructed_image = x.detach()
return reconstructed_image
if __name__ == "__main__":
kspace_data = torch.from_numpy(kspace)
reconstructed_image = reconstruct_image(
kspace_data,
num_iterations=1000,
learning_rate=0.01,
regularization_weight=0.01
) I am not sure if this comes from my wrongful usage or generation of the operator or I probably had not understand the reconstruction problem well enough. BTW, is there test or example being done with the iterative reconstruction using the off-resonance NUFFT from MRI-NUFFT? I would love to learn more on it. I fed the reconstruction inverse problem with our operator into some other solver such as the solve_norm_cg in torchopt. Similar result is produced where the result image is the same with the initial guess and the backward is never reached. Very sorry for this being this long of a message. It is quite late at where I am and my communication ability has dropped significantly. If the operator can work well with the torch based method, half of my problem could be solved, and I feel like we are getting close. Wenshang |
(I took the liberty of editing your comment for formatting purposes) So If I understood your code well you want to use autodiff( wrt to data) with off-resonnance correction. I think we miss a proper implementation of that, and there might be an error on the order of initialization. I think what we want is
To do so, its probably best to things manually here: from mrinufft.operators import get_operator, MRIFourierCorrected
from mrinufft.operators.autodiff import MRINufftAutograd
nufft_base = get_operator("cufinufft", samples, shape, n_coils, n_batchs,smaps=smaps)
nufft_with_orc = MRIFourierCorrected(nufft_base, b0_map, readout_time, ...)
nufft_with_orc_and_autodiff = MRINufftAutoGrad(nufft_with_orc, wrt_data=True)
Not yet, but I think that your comment is an early draft for one, if you could extend it (and use standard phantom data available from You can probably take a deeper look at for inspiration |
Do you want to cover adding example in this PR, or can we merge this? |
Also @wenshangwang if you are interested, feel free let us know if you want to make a short example. We can help you along if required. |
we can merge I think, and write an example later (ie in a new PR) |
fixes #203.