-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
61 lines (47 loc) · 1.57 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.utils.data as data
from torchvision import transforms
from PIL import Image
import os
# Image transforms.
# Source (grayscale) images are normalized with (0.1307,)/(0.3081,)
# (the canonical MNIST mean/std); target (RGB) images are scaled to
# roughly [-1, 1] via mean/std of 0.5 per channel.
_SOURCE_MEAN, _SOURCE_STD = (0.1307,), (0.3081,)
_TARGET_MEAN = _TARGET_STD = (0.5, 0.5, 0.5)

img_transform_source = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize(mean=_SOURCE_MEAN, std=_SOURCE_STD),
])

img_transform_target = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize(mean=_TARGET_MEAN, std=_TARGET_STD),
])
def sample_batch(data_iter, source):
    """Draw one batch from *data_iter* and attach domain labels.

    Parameters
    ----------
    data_iter : iterator yielding (images, class_labels) pairs,
                e.g. ``iter(DataLoader(...))``.
    source    : bool -- True for source-domain batches (domain label 0),
                False for target-domain batches (domain label 1).

    Returns
    -------
    (images, class_labels, domain_labels), moved to the GPU when CUDA
    is available, otherwise left on the CPU.
    """
    # next(it) replaces the removed `.next()` method: Python 3 iterators
    # (and modern PyTorch DataLoader iterators) only support the builtin.
    img, label = next(data_iter)
    batch_size = len(label)
    # Domain labels: 0 = source domain, 1 = target domain.
    if source:
        domain_label = torch.zeros(batch_size).long()
    else:
        domain_label = torch.ones(batch_size).long()
    # Fall back to CPU so this helper also runs on machines without CUDA;
    # on a CUDA machine the behavior matches the previous unconditional
    # `.cuda()` calls.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return img.to(device), label.to(device), domain_label.to(device)
class GetLoader(data.Dataset):
    """Dataset reading (image_path, label) pairs from a plain-text list file.

    Each non-blank line of *data_list* is expected to look like
    ``<relative/image/path> <label>`` — path and label separated by the
    last space on the line.
    """

    def __init__(self, data_root, data_list, transform=None):
        """
        Parameters
        ----------
        data_root : directory that the image paths in the list file are
                    relative to.
        data_list : path to the text file with one "<path> <label>" per line.
        transform : optional callable applied to each loaded PIL image.
        """
        self.root = data_root
        self.transform = transform
        # Context manager guarantees the file handle is closed even if
        # reading raises (the original left it open on error).
        with open(data_list, 'r') as f:
            lines = f.readlines()

        self.img_paths = []
        self.img_labels = []
        for line in lines:
            entry = line.rstrip('\n')
            if not entry:
                continue  # skip blank lines instead of creating bogus entries
            # Split on the LAST space so labels of any width parse correctly;
            # the previous fixed-offset slicing (line[:-3] / line[-2]) only
            # worked for single-character labels followed by a newline.
            path, label = entry.rsplit(' ', 1)
            self.img_paths.append(path)
            self.img_labels.append(label)
        self.n_data = len(self.img_paths)

    def __getitem__(self, item):
        """Return ``(image, int_label)`` for index *item*."""
        img_path, label = self.img_paths[item], self.img_labels[item]
        img = Image.open(os.path.join(self.root, img_path)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, int(label)

    def __len__(self):
        """Number of (path, label) entries parsed from the list file."""
        return self.n_data