-
Notifications
You must be signed in to change notification settings - Fork 57
/
DG_aug.py
executable file
·73 lines (62 loc) · 2.76 KB
/
DG_aug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
from loss_dp import DPLoss
import network as net
from contrastive_loss_m import SupConLoss_m
class DDLearn(nn.Module):
    """Shared feature extractor with an activity head and an augmentation head.

    ``forward`` computes four training losses:
      * activity cross-entropy (``criterion``),
      * augmentation cross-entropy (``criterion_a``),
      * an optional ``DPLoss`` term between original and augmented features,
      * a ``SupConLoss_m`` contrastive loss.
    """

    def __init__(self, n_feature, n_act_class, n_aug_class, dataset, dp):
        """Build the backbone, the two linear heads and the loss objects.

        Args:
            n_feature: Width of the feature vector consumed by both linear
                heads (input size of ``act_cls`` and ``aug_cls``).
            n_act_class: Number of activity classes (output size of
                ``act_cls``).
            n_aug_class: Number of augmentation classes (output size of
                ``aug_cls``).
            dataset: Dataset name. ``'uschad'`` selects ``net.Network_usc``;
                any other value selects ``net.Network``.
            dp: ``loss_type`` forwarded to ``DPLoss``. The string ``'no'``
                disables that loss term.
        """
        super(DDLearn, self).__init__()
        self.n_feature = n_feature
        self.n_act_class = n_act_class
        self.n_aug_class = n_aug_class
        self.dataset = dataset
        self.dp = dp
        if dataset == 'uschad':
            self.feature_module = net.Network_usc(n_feature, dataset)
        else:
            self.feature_module = net.Network(n_feature, dataset)
        self.act_cls = nn.Linear(n_feature, n_act_class)
        self.aug_cls = nn.Linear(n_feature, n_aug_class)
        self.criterion = nn.CrossEntropyLoss()
        self.criterion_a = nn.CrossEntropyLoss()
        self.con = SupConLoss_m(contrast_mode='all')
        # Parameter groups (backbone, activity head, augmentation head),
        # presumably handed to an optimizer by the caller.
        self.params = [
            {'params': self.feature_module.parameters()},
            {'params': self.act_cls.parameters()},
            {'params': self.aug_cls.parameters()},
        ]

    def forward(self, x_ori, x_onlyaug, labels):
        """Run both input batches through the model and compute the losses.

        Args:
            x_ori: Batch of original samples.
            x_onlyaug: Batch of augmented samples. It must have the same
                batch size as ``x_ori``, because the two feature batches are
                concatenated along dim 1 for the contrastive loss.
            labels: 4-tuple ``(actlabel_ori, actlabel_aug, auglabel_ori,
                auglabel_aug)`` of class-index targets, one per sample.

        Returns:
            Tuple ``(actlabel_p, loss_c, loss_selfsup, loss_dp, con_loss)``:
              * activity logits for the concatenated batch (original rows
                first, then augmented rows),
              * the activity cross-entropy,
              * the augmentation cross-entropy,
              * the ``DPLoss`` value, or a 1-element zero tensor on the
                features' device when ``self.dp == 'no'``,
              * the contrastive loss.
        """
        actlabel_ori, actlabel_aug, auglabel_ori, auglabel_aug = labels
        feature_ori = self.feature_module(x_ori)
        feature_aug = self.feature_module(x_onlyaug)

        # Both heads see original rows followed by augmented rows. The
        # targets are concatenated in the same order so rows line up.
        feature_all = torch.cat((feature_ori, feature_aug), dim=0)
        auglabel_true = torch.cat((auglabel_ori, auglabel_aug), dim=0)
        actlabel_true = torch.cat((actlabel_ori, actlabel_aug), dim=0)
        auglabel_p = self.predict_aug(feature_all)
        actlabel_p = self.predict_act(feature_all)
        loss_c = self.criterion(actlabel_p, actlabel_true)
        loss_selfsup = self.criterion_a(auglabel_p, auglabel_true)

        # Allocate the placeholder on the features' device rather than via
        # .cuda(), so the model also runs on CPU or on a non-default GPU.
        loss_dp = torch.zeros(1, device=feature_ori.device)
        if self.dp != 'no':
            # NOTE(review): a fresh DPLoss is built on every call. If it owns
            # learnable parameters, they are not registered on this module;
            # confirm against loss_dp.
            dp_layer = DPLoss(
                loss_type=self.dp, input_dim=self.n_feature)
            loss_dp = dp_layer.compute(feature_ori, feature_aug)

        # Stack the two feature batches as two views per row (new dim 1).
        # Assumes row i of x_onlyaug is paired with row i of x_ori; TODO confirm.
        # NOTE(review): `views` has B rows but 2B labels are passed. Verify
        # this is what SupConLoss_m expects.
        views = torch.cat(
            [feature_ori.unsqueeze(1), feature_aug.unsqueeze(1)], dim=1)
        con_loss = self.con(views, torch.cat([actlabel_ori, actlabel_aug]))
        return actlabel_p, loss_c, loss_selfsup, loss_dp, con_loss

    def test_predict(self, x_ori, x_aug):
        """Return ``(activity_logits, augmentation_logits)`` for evaluation.

        Activity logits are computed on ``x_ori`` only. Augmentation logits
        are computed on the dim-0 concatenation of ``x_ori`` and ``x_aug``.
        """
        actlabel_p = self.act_cls(self.feature_module(x_ori))
        auglabel_p = self.aug_cls(self.feature_module(
            torch.cat((x_ori, x_aug), dim=0)))
        return actlabel_p, auglabel_p

    def predict_act(self, feature):
        """Apply the activity classification head to ``feature``."""
        act_predict = self.act_cls(feature)
        return act_predict

    def predict_aug(self, feature):
        """Apply the augmentation classification head to ``feature``."""
        aug_predict = self.aug_cls(feature)
        return aug_predict