Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CBloss and DSAN models #139

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
391 changes: 391 additions & 0 deletions research/xidian/DSAN/README.md

Large diffs are not rendered by default.

391 changes: 391 additions & 0 deletions research/xidian/DSAN/README_CN.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions research/xidian/DSAN/checkpoint/Readme.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Store weights here
1 change: 1 addition & 0 deletions research/xidian/DSAN/data/Readme.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Store datasets here
52 changes: 52 additions & 0 deletions research/xidian/DSAN/data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''Dataloader'''
import os
from mindspore import dataset
from mindspore.dataset import transforms
from mindspore.dataset import vision
from mindspore import dtype as mstype


def load_training(root_path, dir, batch_size):
data = dataset.ImageFolderDataset(dataset_dir=os.path.join(root_path, dir,'images'), shuffle=True, decode=True)
transform_list = transforms.Compose(
[vision.Resize([256, 256]),
vision.RandomCrop(224),
vision.RandomHorizontalFlip(),
vision.ToTensor()])
image_folder_dataset = data.map(operations=transform_list, input_columns="image")
type_cast_op = transforms.TypeCast(mstype.int32)
image_folder_dataset = image_folder_dataset.map(operations=type_cast_op, input_columns="label", num_parallel_workers=1)
image_folder_dataset = image_folder_dataset.batch(batch_size=batch_size,drop_remainder=True)
return image_folder_dataset

def load_testing(root_path, dir, batch_size):
data = dataset.ImageFolderDataset(dataset_dir=os.path.join(root_path, dir, 'images'), shuffle=True, decode=True)
transform_list = transforms.Compose(
[vision.Resize([224, 224]),
vision.ToTensor()])
image_folder_dataset = data.map(operations=transform_list, input_columns="image")
type_cast_op = transforms.TypeCast(mstype.int32)
image_folder_dataset = image_folder_dataset.map(operations=type_cast_op, input_columns="label", num_parallel_workers=1)
image_folder_dataset = image_folder_dataset.batch(batch_size=batch_size)
return image_folder_dataset

def load_data(root_path, src, tar, batch_size):
loader_src = load_training(root_path, src, batch_size)
loader_tar = load_training(root_path, tar, batch_size)
loader_tar_test = load_testing(
root_path, tar, batch_size)
return loader_src, loader_tar, loader_tar_test
49 changes: 49 additions & 0 deletions research/xidian/DSAN/default_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "CPU" # Ascend
need_modelarts_dataset_unzip: True

# ==============================================================================
# export option
model_root: "checkpoint"
ckpt_file: "model.ckpt"
file_name: "net"
file_format: "MINDIR" # AIR,MINDIR,ONNX

# params for train
nepoch: 200
lr: [0.001, 0.01, 0.01]
seed: 2021
weight: 0.5
momentum: 0.9
decay: 5.0e-4
bottleneck: True
log_interval: 10

# params for dataset
nclass: 31
batch_size: 32
src: 'amazon'
tar: 'webcam'
dataset_path: 'data/OFFICE31'
image_height: 224
image_width: 224
####################
---

# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
113 changes: 113 additions & 0 deletions research/xidian/DSAN/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Evaluation for DSAN """
import os
import time
import numpy as np
from sklearn.metrics import accuracy_score
import mindspore as ms
from mindspore import context
from models.DSAN import DSAN
from data_loader import load_data
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id, get_device_num


def modelarts_pre_process():
'''modelarts pre process function.'''
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
int(int(time.time() - s_time) % 60)))
print("Extract Done.")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")

if config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)

sync_lock = "/tmp/unzip_sync.lock"

# Each server contains 8 devices as most.
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass

while True:
if os.path.exists(sync_lock):
break
time.sleep(1)

print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))


def test(model, dataloader):
acc = 0
preds = []
lables = []
for data in dataloader.create_dict_iterator():
data,label=data['image'],data['label']
pred = model.predict(data)
pred_cls = pred.argmax(1)
preds.extend(pred_cls.asnumpy())
lables.extend(label.asnumpy())
acc = accuracy_score(preds, lables)
print('\nTest set: Accuracy: {}%\n'.format(acc*100))
return acc

@moxing_wrapper(pre_process=modelarts_pre_process)
def run_eval():
current_dir = os.path.dirname(os.path.abspath(__file__))
device_target = config.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if device_target == "Ascend":
context.set_context(device_id=get_device_id())
model = DSAN(num_classes=config.nclass)
weight_path = os.path.join(current_dir, config.model_root, config.ckpt_file)
model_dict = ms.load_checkpoint(weight_path)
ms.load_param_into_net(model,model_dict)
dataloaders = load_data(os.path.join(current_dir, config.dataset_path), config.src,
config.tar, config.batch_size)
accuracy = test(model, dataloaders[-1])

if __name__ == '__main__':
run_eval()

49 changes: 49 additions & 0 deletions research/xidian/DSAN/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Export checkpoint file into air, mindir models"""
import os
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, export, context
from model_utils.config import config
from model_utils.moxing_adapter import moxing_wrapper
from model_utils.device_adapter import get_device_id
from models.DSAN import DSAN_for_export


def modelarts_pre_process():
'''modelarts pre process function.'''
config.file_name = os.path.join(config.output_path, config.file_name)


@moxing_wrapper(pre_process=modelarts_pre_process)
def run_export():
current_dir = os.path.dirname(os.path.abspath(__file__))
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=get_device_id())
current_dir = os.path.dirname(os.path.abspath(__file__))
model = DSAN_for_export(num_classes=config.nclass)
weight_path = os.path.join(current_dir, config.model_root, config.ckpt_file)
model_dict = ms.load_checkpoint(weight_path)
ms.load_param_into_net(model,model_dict)
input_arr = Tensor(np.ones([config.batch_size, 3, config.image_height, config.image_width]), ms.float32)
export(model, input_arr, file_name=config.file_name, file_format=config.file_format)


if __name__ == '__main__':
run_export()
105 changes: 105 additions & 0 deletions research/xidian/DSAN/lmmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''Lmmd'''
import numpy as np
import mindspore.nn as nn
import mindspore as ms
import mindspore.ops as ops
from mindspore import Tensor
import numpy as np


class LMMD_loss(nn.Cell):
def __init__(self, class_num=31, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None):
super(LMMD_loss, self).__init__()
self.class_num = class_num
self.kernel_num = kernel_num
self.kernel_mul = kernel_mul
self.fix_sigma = fix_sigma
self.kernel_type = kernel_type

def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
n_samples = int(source.shape[0]) + int(target.shape[0])
total = ops.concat([source, target], axis=0)
total0 = ops.broadcast_to(ops.expand_dims(total, 0),
(int(total.shape[0]), int(total.shape[0]), int(total.shape[1])))
total1 = ops.broadcast_to(ops.expand_dims(total, 1),
(int(total.shape[0]), int(total.shape[0]), int(total.shape[1])))
bandwidth = ops.ReduceSum()(((total0 - total1) ** 2).sum(2)) / (n_samples ** 2 - n_samples)
bandwidth /= kernel_mul ** (kernel_num // 2)
bandwidth_list = []
for i in range(kernel_num):
bandwidth_list.append(bandwidth ** (kernel_mul ** i))
kernel_val = [ops.exp(-((total0 - total1) ** 2).sum(2) / bandwidth_temp)
for bandwidth_temp in bandwidth_list]
return sum(kernel_val)

def construct(self, source, target, s_label, t_label, s_label_pre, weight):
batch_size = source.shape[0]
weight_ss, weight_tt, weight_st = self.cal_weight(
s_label, t_label, batch_size=batch_size, class_num=self.class_num)
weight_ss = Tensor.from_numpy(weight_ss)
weight_tt = Tensor.from_numpy(weight_tt)
weight_st = Tensor.from_numpy(weight_st)
kernels = self.guassian_kernel(source, target,
kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
loss = ms.Tensor([0])
if np.sum(ops.isnan(sum(kernels)).asnumpy()):
return loss
SS = kernels[:batch_size, :batch_size]
TT = kernels[batch_size:, batch_size:]
ST = kernels[:batch_size, batch_size:]
loss += ops.ReduceSum()(weight_ss * SS + weight_tt * TT - 2 * weight_st * ST)
loss_cls = nn.NLLLoss()
loss_cls_value = loss_cls(ops.log_softmax(s_label_pre,axis=1), s_label)
loss_final = weight * loss + loss_cls_value
print('loss_lmmd:{},loss_cls:{},weight:{}'.format(loss, loss_cls_value.mean(), weight))
return loss_final

def convert_to_onehot(self, sca_label, class_num=31):
sca_label = np.array(sca_label, dtype=int)
return np.eye(class_num)[sca_label]

def cal_weight(self, s_label, t_label, batch_size=32, class_num=31):
batch_size = s_label.shape[0]
s_sca_label = s_label.asnumpy()
s_vec_label = self.convert_to_onehot(s_sca_label, class_num=self.class_num)
s_sum = np.sum(s_vec_label, axis=0).reshape(1, class_num)
s_sum[s_sum == 0] = 100
s_vec_label = s_vec_label / s_sum
t_sca_label = t_label.max(1)[1].asnumpy()
t_sca_label = t_label.argmax(1).asnumpy()
t_vec_label = t_label.asnumpy()
t_sum = np.sum(t_vec_label, axis=0).reshape(1, class_num)
t_sum[t_sum == 0] = 100
t_vec_label = t_vec_label / t_sum
index = list(set(s_sca_label) & set(t_sca_label))
mask_arr = np.zeros((batch_size, class_num))
mask_arr[:, index] = 1
t_vec_label = t_vec_label * mask_arr
s_vec_label = s_vec_label * mask_arr
weight_ss = np.matmul(s_vec_label, s_vec_label.T)
weight_tt = np.matmul(t_vec_label, t_vec_label.T)
weight_st = np.matmul(s_vec_label, t_vec_label.T)
length = len(index)
if length != 0:
weight_ss = weight_ss / length
weight_tt = weight_tt / length
weight_st = weight_st / length
else:
weight_ss = np.array([0])
weight_tt = np.array([0])
weight_st = np.array([0])
return weight_ss.astype('float32'), weight_tt.astype('float32'), weight_st.astype('float32')
2 changes: 2 additions & 0 deletions research/xidian/DSAN/model_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .config import config
__all__ = (config)
Loading