-
Notifications
You must be signed in to change notification settings - Fork 1
/
datamodules.py
136 lines (92 loc) · 5.54 KB
/
datamodules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from pathlib import Path
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data.sampler import SubsetRandomSampler
from utils import *
from PIL import Image
import mirabest
#config_dict, config = parse_config('config1.txt')
#data
#transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((datamean ,), (datastd,))])
class MiraBestDataModule():
    """Data module wrapping the MiraBest radio-galaxy datasets.

    Builds shuffled train/validation DataLoaders (with optional rotation
    augmentation) and an un-augmented test DataLoader from the MiraBest
    "confident" subset, optionally concatenated with the "uncertain" one.
    """

    def __init__(self, config_dict, config, random_seed=15):
        """Read data/training settings from the parsed config.

        Args:
            config_dict: nested dict with 'training' ('batch_size',
                'frac_val', 'imsize') and 'data' ('dataset', 'datadir',
                'datamean', 'datastd') sections.
            config: raw parsed config object (kept for interface parity;
                not used by this class).
            random_seed: seed for the numpy shuffle of the train/val split.
        """
        self.batch_size = config_dict['training']['batch_size']
        self.validation_split = config_dict['training']['frac_val']
        self.dataset = config_dict['data']['dataset']
        self.path = Path(config_dict['data']['datadir'])
        self.datamean = config_dict['data']['datamean']
        self.datastd = config_dict['data']['datastd']
        # NOTE(review): augmentation is hard-coded on (string flag to match
        # the config-file convention); consider reading
        # config_dict['data']['augment'] or parameterizing train_val_loader.
        self.augment = 'True'
        self.imsize = config_dict['training']['imsize']
        self.random_seed = random_seed

    def transforms(self, aug):
        """Return the torchvision transform pipeline.

        Args:
            aug: False or 'False' for the plain ToTensor+Normalize pipeline;
                anything else selects the augmented pipeline. Accepts either
                a real bool or the string flags used in the config file —
                the original compared only against the string 'False', so a
                boolean False silently fell through to the augmented branch.
        """
        if aug is False or aug == 'False':
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((self.datamean,), (self.datastd,)),
            ])
        else:
            print('AUGMENTING')
            # crop, pad (reflect), rotate, to tensor, normalise
            crop = transforms.CenterCrop(self.imsize)
            pad = transforms.Pad((0, 0, 1, 1), fill=0)
            transform = transforms.Compose([
                crop,
                pad,
                # resample= was renamed interpolation= in torchvision 0.9+
                transforms.RandomRotation(360, interpolation=transforms.InterpolationMode.NEAREST, expand=False),
                transforms.ToTensor(),
                transforms.Normalize((self.datamean,), (self.datastd,)),
            ])
        return transform

    def train_val_loader(self):
        """Build shuffled train/validation DataLoaders.

        Returns:
            (train_loader, validation_loader, train_sampler, valid_sampler)

        Raises:
            ValueError: if the configured dataset name is not recognised.
        """
        transform = self.transforms(self.augment)
        if self.dataset == 'MBFRConf':
            train_data = mirabest.MBFRConfident(self.path, train=True,
                                                transform=transform, target_transform=None,
                                                download=True)
        elif self.dataset == 'MBFRConf+Uncert':
            train_data_confident = mirabest.MBFRConfident(self.path, train=True,
                                                          transform=transform, target_transform=None,
                                                          download=True)
            train_data_uncertain = mirabest.MBFRUncertain(self.path, train=True,
                                                          transform=transform, target_transform=None,
                                                          download=True)
            # concatenate confident and uncertain sources
            train_data = torch.utils.data.ConcatDataset([train_data_confident, train_data_uncertain])
        else:
            # Previously an unknown name surfaced as a NameError further down;
            # fail fast with a clear message instead.
            raise ValueError(f"Unknown dataset: {self.dataset!r}")
        # train/validation split: shuffle indices deterministically, then
        # carve off the first `split` indices for validation.
        dataset_size = len(train_data)
        indices = list(range(dataset_size))
        # Use the configured fraction (the original hard-coded 0.2 and
        # ignored config_dict['training']['frac_val']).
        split = int(np.floor(self.validation_split * dataset_size))
        np.random.seed(self.random_seed)
        np.random.shuffle(indices)
        train_indices, val_indices = indices[split:], indices[:split]
        # Samplers restrict each loader to its index subset.
        train_sampler = SubsetRandomSampler(train_indices)
        valid_sampler = SubsetRandomSampler(val_indices)
        train_loader = torch.utils.data.DataLoader(train_data, batch_size=self.batch_size, sampler=train_sampler)
        validation_loader = torch.utils.data.DataLoader(train_data, batch_size=self.batch_size, sampler=valid_sampler)
        return train_loader, validation_loader, train_sampler, valid_sampler

    def test_loader(self):
        """Build the test DataLoader (never augmented).

        Returns:
            torch DataLoader over the configured test split.

        Raises:
            ValueError: if the configured dataset name is not recognised.
        """
        # String flag for consistency with the config convention; the
        # original passed bool False, which the old string comparison
        # mishandled and augmented the test set.
        transform = self.transforms('False')
        if self.dataset == 'MBFRConf':
            test_data = mirabest.MBFRConfident(self.path, train=False,
                                               transform=transform, target_transform=None,
                                               download=False)
        elif self.dataset == 'MBFRConf+Uncert':
            # confident
            test_data_confident = mirabest.MBFRConfident(self.path, train=False,
                                                         transform=transform, target_transform=None,
                                                         download=True)
            # uncertain
            test_data_uncertain = mirabest.MBFRUncertain(self.path, train=False,
                                                         transform=transform, target_transform=None,
                                                         download=True)
            # concatenate confident and uncertain sources
            test_data = torch.utils.data.ConcatDataset([test_data_confident, test_data_uncertain])
        else:
            raise ValueError(f"Unknown dataset: {self.dataset!r}")
        test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=self.batch_size, shuffle=True)
        return test_loader