Fix bugs and update the code to support PyTorch 1.13.1 #18

Open · wants to merge 2 commits into master
35 changes: 33 additions & 2 deletions M3SDA/code_MSDA_digit/README.md
@@ -8,12 +8,21 @@ PyTorch implementation for **Moment Matching for Multi-Source Domain Adaptation**
The code has been tested on Python 3.6 + PyTorch 0.3. To run the training and testing code, use the following script:

## Installation
-- Install PyTorch (Works on Version 0.3) and dependencies from http://pytorch.org.
+- We strongly suggest installing the requirements in a **Conda** environment, except for torchnet, which is installed with pip.
+- Make sure your environment has a GPU, since this code was designed to run on one.
+- Install PyTorch (works on version 1.13.1+cu116) and dependencies from http://pytorch.org.
- Install torchvision from source.
-- Install torchnet as follows
+- Install torchnet as follows:
```
pip install git+https://github.com/pytorch/tnt.git@master
```
- If an error related to torchnet appears, install it with:
```
pip install torchnet
```
- Install gdown (used below to download the Digit-Five dataset).
- Install scipy and any other dependencies your environment requires. A consolidated install sketch follows this list.
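
For reference, a consolidated setup matching the versions above might look like this (a sketch; the exact CUDA wheel index depends on your driver and system):
```
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install scipy gdown
pip install git+https://github.com/pytorch/tnt.git@master
```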

## Digit-Five Download
Since many researchers have emailed us about the Digit-Five data, we share the Digit-Five dataset used in our experiments at the following download link:

@@ -22,6 +31,28 @@ https://drive.google.com/open?id=1A4RJOFj4BJkmliiEL7g9WzNIDUHLxfmm
Keep in mind that the MNIST-M dataset was generated by us, so this subset may differ from the one released with the DANN paper.

If you find the Digit-Five dataset useful for your research, please cite our paper.

## Run your first experiment

1. Navigate to code_MSDA_digit/
```
cd VisionLearningGroup.github.io/M3SDA/code_MSDA_digit/
```
2. Make the script executable:
```
chmod +x experiment_do.sh
```
3. Download the Digit-Five archive, unzip it, and move it into a new folder called data:
```
gdown 1A4RJOFj4BJkmliiEL7g9WzNIDUHLxfmm
unzip Digit-Five.zip
mv Digit-Five data
```
4. Run the shell script:
```
./experiment_do.sh usps 100 0 record/usps_MSDA_beta
```

## DomainNet
The DomainNet dataset can be downloaded from the following link:
[http://ai.bu.edu/M3SDA/](http://ai.bu.edu/M3SDA/)
@@ -0,0 +1,40 @@
import numpy as np
from scipy.io import loadmat

base_dir = './data'
def load_mnist(scale=True, usps=False, all_use=False):
    mnist_data = loadmat(base_dir + '/mnist_data.mat')
    if scale:
        mnist_train = np.reshape(mnist_data['train_32'], (55000, 32, 32, 1))
        mnist_test = np.reshape(mnist_data['test_32'], (10000, 32, 32, 1))
        mnist_train = np.concatenate([mnist_train, mnist_train, mnist_train], 3)
        mnist_test = np.concatenate([mnist_test, mnist_test, mnist_test], 3)
        mnist_train = mnist_train.transpose(0, 3, 1, 2).astype(np.float32)
        mnist_test = mnist_test.transpose(0, 3, 1, 2).astype(np.float32)
        mnist_labels_train = mnist_data['label_train']
        mnist_labels_test = mnist_data['label_test']
    else:
        mnist_train = mnist_data['train_28']
        mnist_test = mnist_data['test_28']
        mnist_labels_train = mnist_data['label_train']
        mnist_labels_test = mnist_data['label_test']
        mnist_train = mnist_train.astype(np.float32)
        mnist_test = mnist_test.astype(np.float32)
        mnist_train = mnist_train.transpose((0, 3, 1, 2))
        mnist_test = mnist_test.transpose((0, 3, 1, 2))
    train_label = np.argmax(mnist_labels_train, axis=1)
    inds = np.random.permutation(mnist_train.shape[0])
    mnist_train = mnist_train[inds]
    train_label = train_label[inds]
    test_label = np.argmax(mnist_labels_test, axis=1)

    mnist_train = mnist_train[:25000]
    train_label = train_label[:25000]
    mnist_test = mnist_test[:25000]
    test_label = test_label[:25000]
    print('mnist train X shape->', mnist_train.shape)
    print('mnist train y shape->', train_label.shape)
    print('mnist test X shape->', mnist_test.shape)
    print('mnist test y shape->', test_label.shape)

    return mnist_train, train_label, mnist_test, test_label
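
As a quick sanity check, the arrays returned by `load_mnist` can be wrapped in a standard PyTorch dataset (a minimal sketch; the import path is an assumption, and the actual training pipeline goes through the `UnalignedDataLoader` shown further below):
```
import torch
from torch.utils.data import TensorDataset, DataLoader

from datasets_.mnist import load_mnist  # path assumed from this PR's imports

train_x, train_y, test_x, test_y = load_mnist(scale=True)
train_set = TensorDataset(torch.from_numpy(train_x), torch.from_numpy(train_y))
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
images, labels = next(iter(train_loader))  # images: (128, 3, 32, 32) float32
```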
@@ -0,0 +1,29 @@
from scipy.io import loadmat
import numpy as np
import sys

sys.path.append('../utils/')
from utils.utils import dense_to_one_hot
base_dir = './data'
def load_svhn():
    svhn_train = loadmat(base_dir + '/svhn_train_32x32.mat')
    svhn_test = loadmat(base_dir + '/svhn_test_32x32.mat')
    svhn_train_im = svhn_train['X']
    svhn_train_im = svhn_train_im.transpose(3, 2, 0, 1).astype(np.float32)

    print('svhn train y shape before dense_to_one_hot->', svhn_train['y'].shape)
    svhn_label = dense_to_one_hot(svhn_train['y'])
    print('svhn train y shape after dense_to_one_hot->', svhn_label.shape)
    svhn_test_im = svhn_test['X']
    svhn_test_im = svhn_test_im.transpose(3, 2, 0, 1).astype(np.float32)
    svhn_label_test = dense_to_one_hot(svhn_test['y'])
    svhn_train_im = svhn_train_im[:25000]
    svhn_label = svhn_label[:25000]
    svhn_test_im = svhn_test_im[:9000]
    svhn_label_test = svhn_label_test[:9000]
    print('svhn train X shape->', svhn_train_im.shape)
    print('svhn train y shape->', svhn_label.shape)
    print('svhn test X shape->', svhn_test_im.shape)
    print('svhn test y shape->', svhn_label_test.shape)

    return svhn_train_im, svhn_label, svhn_test_im, svhn_label_test
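
One detail worth knowing here: the raw SVHN `.mat` labels run 1-10, with 10 standing for the digit "0". Despite its name, `dense_to_one_hot` in this repo's utils presumably performs that remapping; a minimal standalone sketch of the same idea (hypothetical helper, for illustration only):
```
import numpy as np

def remap_svhn_labels(y):
    # Map raw SVHN labels (1-10, where 10 means digit '0') to the 0-9
    # convention used by the other digit datasets.
    y = y.ravel().astype(np.int64)
    y[y == 10] = 0
    return y
```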
@@ -0,0 +1,140 @@
import torch.utils.data
import torchnet as tnt
from builtins import object
import torchvision.transforms as transforms
from datasets_ import Dataset


class PairedData(object):
    def __init__(self, data_loader_A, data_loader_B, data_loader_C, data_loader_D, data_loader_t, max_dataset_size):
        self.data_loader_A = data_loader_A
        self.data_loader_B = data_loader_B
        self.data_loader_C = data_loader_C
        self.data_loader_D = data_loader_D
        self.data_loader_t = data_loader_t

        self.stop_A = False
        self.stop_B = False
        self.stop_C = False
        self.stop_D = False
        self.stop_t = False
        self.max_dataset_size = max_dataset_size

    def __iter__(self):
        self.stop_A = False
        self.stop_B = False
        self.stop_C = False
        self.stop_D = False
        self.stop_t = False

        self.data_loader_A_iter = iter(self.data_loader_A)
        self.data_loader_B_iter = iter(self.data_loader_B)
        self.data_loader_C_iter = iter(self.data_loader_C)
        self.data_loader_D_iter = iter(self.data_loader_D)
        self.data_loader_t_iter = iter(self.data_loader_t)
        self.iter = 0
        return self

    def __next__(self):
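        # Each loader below is restarted when it raises StopIteration, so
        # shorter datasets simply repeat; iteration only ends once every loader
        # has completed at least one full pass (or self.iter exceeds
        # max_dataset_size).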
        A, A_paths = None, None
        B, B_paths = None, None
        C, C_paths = None, None
        D, D_paths = None, None
        t, t_paths = None, None
        try:
            A, A_paths = next(self.data_loader_A_iter)
        except StopIteration:
            if A is None or A_paths is None:
                self.stop_A = True
                self.data_loader_A_iter = iter(self.data_loader_A)
                A, A_paths = next(self.data_loader_A_iter)

        try:
            B, B_paths = next(self.data_loader_B_iter)
        except StopIteration:
            if B is None or B_paths is None:
                self.stop_B = True
                self.data_loader_B_iter = iter(self.data_loader_B)
                B, B_paths = next(self.data_loader_B_iter)
        try:
            C, C_paths = next(self.data_loader_C_iter)
        except StopIteration:
            if C is None or C_paths is None:
                self.stop_C = True
                self.data_loader_C_iter = iter(self.data_loader_C)
                C, C_paths = next(self.data_loader_C_iter)
        try:
            D, D_paths = next(self.data_loader_D_iter)
        except StopIteration:
            if D is None or D_paths is None:
                self.stop_D = True
                self.data_loader_D_iter = iter(self.data_loader_D)
                D, D_paths = next(self.data_loader_D_iter)

        try:
            t, t_paths = next(self.data_loader_t_iter)
        except StopIteration:
            if t is None or t_paths is None:
                self.stop_t = True
                self.data_loader_t_iter = iter(self.data_loader_t)
                t, t_paths = next(self.data_loader_t_iter)

        if (self.stop_A and self.stop_B and self.stop_C and self.stop_D and self.stop_t) or self.iter > self.max_dataset_size:
            self.stop_A = False
            self.stop_B = False
            self.stop_C = False
            self.stop_D = False
            self.stop_t = False
            raise StopIteration()
        else:
            self.iter += 1
            return {'S1': A, 'S1_label': A_paths,
                    'S2': B, 'S2_label': B_paths,
                    'S3': C, 'S3_label': C_paths,
                    'S4': D, 'S4_label': D_paths,
                    'T': t, 'T_label': t_paths}


class UnalignedDataLoader():
    def initialize(self, source, target, batch_size1, batch_size2, scale=32):
        transform = transforms.Compose([
            transforms.Resize(scale),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # dataset_source1 = Dataset(source[1]['imgs'], source['labels'], transform=transform)
        dataset_source1 = Dataset(source[0]['imgs'], source[0]['labels'], transform=transform)
        data_loader_s1 = torch.utils.data.DataLoader(dataset_source1, batch_size=batch_size1, shuffle=True, num_workers=4)
        self.dataset_s1 = dataset_source1

        dataset_source2 = Dataset(source[1]['imgs'], source[1]['labels'], transform=transform)
        data_loader_s2 = torch.utils.data.DataLoader(dataset_source2, batch_size=batch_size1, shuffle=True, num_workers=4)
        self.dataset_s2 = dataset_source2

        dataset_source3 = Dataset(source[2]['imgs'], source[2]['labels'], transform=transform)
        data_loader_s3 = torch.utils.data.DataLoader(dataset_source3, batch_size=batch_size1, shuffle=True, num_workers=4)
        self.dataset_s3 = dataset_source3

        dataset_source4 = Dataset(source[3]['imgs'], source[3]['labels'], transform=transform)
        data_loader_s4 = torch.utils.data.DataLoader(dataset_source4, batch_size=batch_size1, shuffle=True, num_workers=4)
        self.dataset_s4 = dataset_source4

        # for i in range(len(source)):
        #     dataset_source[i] = Dataset(source[i]['imgs'], source[i]['labels'], transform=transform)
        dataset_target = Dataset(target['imgs'], target['labels'], transform=transform)
        data_loader_t = torch.utils.data.DataLoader(dataset_target, batch_size=batch_size2, shuffle=True, num_workers=4)

        self.dataset_t = dataset_target
        self.paired_data = PairedData(data_loader_s1, data_loader_s2, data_loader_s3, data_loader_s4, data_loader_t,
                                      float("inf"))

    def name(self):
        return 'UnalignedDataLoader'

    def load_data(self):
        return self.paired_data

    def __len__(self):
        return min(max(len(self.dataset_s1), len(self.dataset_s2), len(self.dataset_s3), len(self.dataset_s4), len(self.dataset_t)), float("inf"))
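
A sketch of how this loader appears to be used; the dict layout for `source` and `target` is inferred from the `Dataset(...)` calls above, and the variable names are placeholders:
```
# s1..s4 and t are dicts like {'imgs': float32 array, 'labels': int array},
# matching the fields accessed in initialize() above.
loader = UnalignedDataLoader()
loader.initialize([s1, s2, s3, s4], t, batch_size1=128, batch_size2=128)

for batch in loader.load_data():
    x_s1, y_s1 = batch['S1'], batch['S1_label']  # first source domain
    x_t = batch['T']                             # target-domain images
```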
2 changes: 1 addition & 1 deletion M3SDA/code_MSDA_digit/datasets/mnist.py
@@ -4,7 +4,7 @@
base_dir = './data'
def load_mnist(scale=True, usps=False, all_use=False):
    mnist_data = loadmat(base_dir + '/mnist_data.mat')
-    if scale
+    if scale:
        mnist_train = np.reshape(mnist_data['train_32'], (55000, 32, 32, 1))
        mnist_test = np.reshape(mnist_data['test_32'], (10000, 32, 32, 1))
        mnist_train = np.concatenate([mnist_train, mnist_train, mnist_train], 3)
4 changes: 2 additions & 2 deletions M3SDA/code_MSDA_digit/datasets/svhn.py
@@ -6,8 +6,8 @@
from utils.utils import dense_to_one_hot
base_dir = './data'
def load_svhn():
-    svhn_train = loadmat(base_dir + '/train_32x32.mat')
-    svhn_test = loadmat(base_dir + '/test_32x32.mat')
+    svhn_train = loadmat(base_dir + '/svhn_train_32x32.mat')
+    svhn_test = loadmat(base_dir + '/svhn_test_32x32.mat')
    svhn_train_im = svhn_train['X']
    svhn_train_im = svhn_train_im.transpose(3, 2, 0, 1).astype(np.float32)

2 changes: 1 addition & 1 deletion M3SDA/code_MSDA_digit/datasets/unaligned_data_loader.py
@@ -98,7 +98,7 @@ def __next__(self):
class UnalignedDataLoader():
    def initialize(self, source, target, batch_size1, batch_size2, scale=32):
        transform = transforms.Compose([
-            transforms.Scale(scale),
+            transforms.Resize(scale),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
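
`transforms.Scale` was deprecated long ago and has been removed from recent torchvision releases, so this rename is required under PyTorch 1.13.x. One subtlety: with an integer argument, `Resize` (like the old `Scale`) resizes the shorter edge to `scale` and preserves the aspect ratio; to force an exact square output you would pass a tuple instead:
```
transforms.Resize((32, 32))  # exact 32x32 output, ignoring aspect ratio
```
Since the digit images here are already square, the two forms behave the same in practice.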
1 change: 1 addition & 0 deletions M3SDA/code_MSDA_digit/main.py
@@ -76,6 +76,7 @@ def main():
    if not os.path.exists(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)
    if not os.path.exists(args.record_folder):
+        os.mkdir(args.record_folder.split('/')[0])
        os.mkdir(args.record_folder)
    if args.eval_only:
        solver.test(0)
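
The added line works because `os.mkdir` cannot create nested directories: the `record/` parent has to exist before a path like `record/usps_MSDA_beta` can be created. A more robust equivalent (a sketch, not part of this patch) handles any nesting depth and does not fail when the parent already exists:
```
os.makedirs(args.record_folder, exist_ok=True)
```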
6 changes: 3 additions & 3 deletions M3SDA/code_MSDA_digit/solver_MSDA.py
@@ -233,7 +233,7 @@ def test(self, epoch, record_file=None, save_model=False):

            output1 = self.C1(feat)

-            test_loss += F.nll_loss(output1, label).data[0]
+            test_loss += F.nll_loss(output1, label).data
            pred1 = output1.data.max(1)[1]
            k = label.data.size()[0]
            correct1 += pred1.eq(label.data).cpu().sum()
@@ -354,10 +354,10 @@ def train_MSDA(self, epoch, record_file=None):
            if batch_idx % self.interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss1: {:.6f}\t Loss2: {:.6f}\t Discrepancy: {:.6f}'.format(
                    epoch, batch_idx, 100,
-                    100. * batch_idx / 70000, loss_s_c1.data[0], loss_s_c2.data[0], loss_dis.data[0]))
+                    100. * batch_idx / 70000, loss_s_c1.data, loss_s_c2.data, loss_dis.data))
                if record_file:
                    record = open(record_file, 'a')
-                    record.write('%s %s %s\n' % (loss_dis.data[0], loss_s_c1.data[0], loss_s_c2.data[0]))
+                    record.write('%s %s %s\n' % (loss_dis.data, loss_s_c1.data, loss_s_c2.data))
                    record.close()
        return batch_idx
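
Both hunks above address the same breaking change: since PyTorch 0.4, a scalar loss is a zero-dimensional tensor, so the old `loss.data[0]` indexing raises an error. Reading `.data` works for logging, though the idiomatic modern form is `.item()`, which returns a plain Python number (a sketch of the alternative):
```
loss = F.nll_loss(output1, label)
test_loss += loss.item()  # extracts a Python float from a 0-dim tensor
```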
