-
Notifications
You must be signed in to change notification settings - Fork 51
/
pdnet.py
63 lines (60 loc) · 2.12 KB
/
pdnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from .cnn import CNNComplex
from .cross_domain import CrossDomainNet
from ..utils.fourier import FFT, IFFT
class PDNet(CrossDomainNet):
    """Unrolled primal-dual network built on ``CrossDomainNet``.

    The network alternates ``n_iter`` times between a k-space ('K') step and an
    image ('I') step (``domain_sequence='KI' * n_iter``).

    - Each image step uses a residual ``CNNComplex`` producing ``n_primal``
      output channels.
    - Each k-space step uses a residual ``CNNComplex`` producing ``n_dual``
      output channels.
    - When ``primal_only`` is set, the k-space step instead uses the fixed,
      non-learned ``measurements_residual`` function.

    NOTE(review): "PDNet" presumably refers to the learned primal-dual
    reconstruction scheme. How buffers and data consistency are actually
    applied is implemented in ``CrossDomainNet``, which is not visible here.

    Args:
        n_filters: number of filters in each ``CNNComplex``.
        n_primal: image-domain buffer size, and output channels of each
            image-domain CNN.
        n_dual: k-space buffer size, and output channels of each k-space CNN.
        n_iter: number of unrolled K/I iteration pairs.
        primal_only: if True, the k-space buffer is disabled and the learned
            k-space CNNs are replaced by ``measurements_residual``.
        activation: activation passed to every ``CNNComplex`` (default
            ``'relu'``).
        multicoil: forwarded to ``CrossDomainNet``, ``FFT`` and ``IFFT``.
        **kwargs: forwarded unchanged to ``CrossDomainNet.__init__``.
    """

    def __init__(
        self,
        n_filters=32,
        n_primal=5,
        n_dual=5,
        n_iter=10,
        primal_only=False,
        activation='relu',
        multicoil=False,
        **kwargs,
    ):
        self.n_filters = n_filters
        self.n_primal = n_primal
        self.n_dual = n_dual
        self.n_iter = n_iter
        self.primal_only = primal_only
        self.activation = activation
        self.multicoil = multicoil
        super().__init__(
            domain_sequence='KI' * self.n_iter,
            data_consistency_mode='measurements_residual',
            i_buffer_mode=True,
            # The k-space buffer is disabled in primal-only mode.
            k_buffer_mode=not self.primal_only,
            i_buffer_size=self.n_primal,
            k_buffer_size=self.n_dual,
            multicoil=self.multicoil,
            **kwargs,
        )
        # Masked forward operator and its adjoint.
        self.op = FFT(masked=True, multicoil=self.multicoil)
        self.adj_op = IFFT(masked=True, multicoil=self.multicoil)

        def make_cnns(prefix, n_output_channels):
            # One residual CNNComplex per unrolled iteration, named
            # f'{prefix}_{i}'. Uses the configured activation: previously
            # 'relu' was hard-coded here and `activation` was silently ignored.
            return [
                CNNComplex(
                    n_convs=3,
                    n_filters=self.n_filters,
                    n_output_channels=n_output_channels,
                    activation=self.activation,
                    res=True,
                    name=f'{prefix}_{i}',
                )
                for i in range(self.n_iter)
            ]

        self.image_net = make_cnns('image_net', self.n_primal)
        if not self.primal_only:
            # TODO: check that when multicoil we do not have this
            self.kspace_net = make_cnns('kspace_net', self.n_dual)
        else:
            # TODO: check n dual
            # TODO: code small diff function
            # Non-learned k-space step: the same stateless function is reused
            # for every iteration.
            self.kspace_net = [measurements_residual] * self.n_iter
def measurements_residual(concatenated_kspace):
    """Return the difference between the first two channels of the last axis.

    Channel 0 is treated as the current k-space estimate and channel 1 as the
    original k-space. The result is ``current - original``. The last axis is
    kept with size 1, because slices are used rather than integer indices.

    NOTE(review): assumes ``concatenated_kspace`` supports ``...`` slicing and
    subtraction (NumPy array / framework tensor) and has at least two channels
    on its last axis. Confirm against callers.
    """
    current_channel = slice(0, 1)
    original_channel = slice(1, 2)
    return (
        concatenated_kspace[..., current_channel]
        - concatenated_kspace[..., original_channel]
    )