forked from hanzhu97702/IEEE_TGRS_MUNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
64 lines (59 loc) · 2.49 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn
from torch.nn import init
def Init_Weights(net, init_type, gain):
    """Re-initialise the Conv / Linear / BatchNorm parameters of ``net`` in place.

    Every sub-module visited by ``net.apply`` is handled as follows:

    * class name contains ``Conv`` or ``Linear`` and it has a weight: the
      weight is drawn according to ``init_type`` and the bias (if any) is
      zeroed;
    * class name contains ``BatchNorm2d`` or ``BatchNorm1d``: the weight (if
      any) is drawn from N(1.0, 0.02) and the bias (if any) is zeroed;
    * anything else is left untouched.

    Args:
        net: ``nn.Module`` to initialise.
        init_type: ``'normal'`` (weights ~ N(0, gain)), ``'xavier'`` (Xavier /
            Glorot normal scaled by ``gain``), ``'kaiming'`` (Kaiming normal,
            fan-in mode; ``gain`` is ignored) or ``'orthogonal'`` (scaled by
            ``gain``).
        gain: standard deviation for ``'normal'``; scaling gain for
            ``'xavier'`` and ``'orthogonal'``.

    Raises:
        ValueError: if ``init_type`` is not one of the names above. The
            check happens before any parameter is modified, so an invalid
            name no longer silently leaves the default initialisation in place.
    """
    weight_initializers = {
        'normal': lambda w: init.normal_(w, 0.0, gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=gain),
    }
    if init_type not in weight_initializers:
        raise ValueError(
            'unknown init_type %r; expected one of %s'
            % (init_type, sorted(weight_initializers)))
    init_weight = weight_initializers[init_type]

    print('Init Network Weights')

    def init_func(m):
        classname = m.__class__.__name__
        # BatchNorm with affine=False (and some other layers) set these to None.
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            # torch.nn.init functions already run under no_grad, so the
            # Parameters can be passed directly (no `.data` needed).
            init_weight(weight)
            if bias is not None:
                init.constant_(bias, 0.0)
        elif 'BatchNorm2d' in classname or 'BatchNorm1d' in classname:
            if weight is not None:
                init.normal_(weight, 1.0, 0.02)
            if bias is not None:
                init.constant_(bias, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
class MUNet(nn.Module):
    """Two-branch unmixing network with a channel-attention gate.

    ``forward(x, y)``:

    1. encodes ``x`` (``band`` channels) to ``num_classes`` channels with a
       stack of 1x1 conv / BatchNorm / ReLU layers (``fc_hsi``);
    2. derives a per-channel gate in (0, 1) from ``y`` (``ldr_dim`` channels)
       via ``spectral_fe`` followed by the squeeze-and-excitation-style
       ``spectral_se`` (1x1 convs ending in a sigmoid);
    3. multiplies the encoding by the gate and applies a softmax over the
       channel dimension, so the returned ``abu`` is non-negative and sums
       to one per pixel;
    4. reconstructs ``band`` channels with a bias-free 1x1 conv + ReLU.

    All convolutions are 1x1, so pixels are processed independently. The
    BatchNorm2d layers require 4-D ``(N, C, H, W)`` inputs. The two branch
    outputs are combined with an element-wise product, so their batch and
    spatial sizes must match or broadcast.

    NOTE(review): ``num_classes`` is presumably the number of endmembers and
    ``y`` presumably an auxiliary (e.g. LiDAR-derived) feature map. The
    decoder weight (band x num_classes x 1 x 1) then plays the role of the
    endmember matrix. Confirm against the training script.

    Args:
        band: number of channels of ``x`` and of the reconstruction.
        num_classes: number of abundance channels.
        ldr_dim: number of channels of ``y``.
        reduction: channel-reduction factor for the attention bottleneck.
    """

    def __init__(self, band, num_classes, ldr_dim, reduction):
        super().__init__()
        # Hidden widths are clamped to >= 1. With plain floor division a small
        # ``band`` or ``num_classes < reduction`` produced degenerate
        # zero-channel Conv/BatchNorm layers. Valid configurations keep
        # exactly the same widths (and state_dict shapes) as before.
        half_band = max(1, band // 2)
        quarter_band = max(1, band // 4)
        squeezed = max(1, num_classes // reduction)

        # Encoder for x: band -> band/2 -> band/4 -> num_classes.
        self.fc_hsi = nn.Sequential(
            nn.Conv2d(band, half_band, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(half_band),
            nn.ReLU(),
            nn.Conv2d(half_band, quarter_band, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(quarter_band),
            nn.ReLU(),
            nn.Conv2d(quarter_band, num_classes, kernel_size=1, stride=1, padding=0),
        )
        # Softmax over channels: abundances are non-negative and sum to one.
        self.softmax = nn.Softmax(dim=1)
        # Feature extractor for y: ldr_dim -> squeezed -> num_classes.
        self.spectral_fe = nn.Sequential(
            nn.Conv2d(ldr_dim, squeezed, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(squeezed),
            nn.ReLU(),
            nn.Conv2d(squeezed, num_classes, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(num_classes),
            nn.ReLU(),
        )
        # Bottleneck gate; the sigmoid keeps every weight in (0, 1).
        self.spectral_se = nn.Sequential(
            nn.Conv2d(num_classes, squeezed, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(squeezed, num_classes, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )
        # Bias-free 1x1 conv = per-pixel linear map num_classes -> band.
        # The ReLU keeps the reconstruction non-negative.
        self.decoder = nn.Sequential(
            nn.Conv2d(num_classes, band, kernel_size=1, stride=1, bias=False),
            nn.ReLU(),
        )

    def forward(self, x, y):
        """Return ``(abu, output)``: abundances and reconstruction of ``x``.

        Args:
            x: tensor with ``band`` channels, ``(N, band, H, W)``.
            y: tensor with ``ldr_dim`` channels. Its spatial size must match
               (or broadcast against) that of ``x``.

        Returns:
            abu: ``(N, num_classes, H, W)``, softmax-normalised over dim 1.
            output: ``(N, band, H, W)``, non-negative reconstruction.
        """
        encode = self.fc_hsi(x)
        # Spectral attention computed from the second input.
        y_fe = self.spectral_fe(y)
        attention = self.spectral_se(y_fe)
        abu = self.softmax(torch.mul(encode, attention))
        # Matrix (dot) product: the 1x1 decoder linearly mixes the abundances
        # of every pixel.
        output = self.decoder(abu)
        return abu, output