/
data_loader.py
107 lines (81 loc) · 4.01 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import random
from random import shuffle
import numpy as np
import torch
from torch.utils import data
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image
class ImageFolder(data.Dataset):
    """Paired image / ground-truth (GT) mask dataset read from a directory.

    ``root`` must contain an ``image`` directory with the inputs and a
    ``mask`` directory with one ``.png`` GT file per input. The GT file name
    is the input's base name up to its first dot plus ``.png``
    (``image/a.jpg`` is paired with ``mask/a.png``).

    Each item is a tuple ``(image, GT)`` of tensors produced by ``ToTensor``;
    ``image`` is additionally normalized with mean 0.5 and std 0.5 on each of
    its three channels.
    """

    def __init__(self, root, image_size=224, mode='train', augmentation_prob=0.4):
        """Collect the image paths and store the preprocessing settings.

        Args:
            root: Directory containing the ``image`` and ``mask`` folders.
            image_size: Stored on the instance; the preprocessing in this
                class does not read it (the output width is fixed at 256).
            mode: Augmentation is only applied when this equals ``'train'``.
            augmentation_prob: Probability, per item, that the augmentation
                branch runs in ``'train'`` mode.
        """
        self.root = root
        # GT : Ground Truth
        self.GT_paths_root = os.path.join(root, 'mask')
        self.image_paths_root = os.path.join(root, 'image')
        # Sorted so that an index maps to the same file on every run and
        # platform; os.listdir returns entries in arbitrary order.
        self.image_paths = [
            os.path.join(self.image_paths_root, img)
            for img in sorted(os.listdir(self.image_paths_root))
            if '.DS' not in img  # skip macOS ".DS_Store" entries
        ]
        self.image_size = image_size
        self.mode = mode
        self.RotationDegree = [0, 90, 180, 270]
        self.augmentation_prob = augmentation_prob
        print("image count in {} path :{}".format(self.mode, len(self.image_paths)))

    def _mask_path(self, image_path):
        """Return the GT mask path that belongs to ``image_path``."""
        # basename handles both '/' and the platform separator that
        # os.path.join inserted when the path list was built.
        filename = os.path.basename(image_path)
        return os.path.join(self.GT_paths_root, filename.split('.')[0] + '.png')

    @staticmethod
    def _output_height(aspect_ratio):
        """Return the final height for width 256: a multiple of 16, at least 16."""
        height = int(256 * aspect_ratio)
        height -= height % 16
        # Very wide images would otherwise round down to a height of 0.
        return max(16, height)

    def __getitem__(self, index):
        """Load one image/GT pair, preprocess both, and return them as tensors.

        Args:
            index: Position in ``self.image_paths``.

        Returns:
            ``(image, GT)``: the normalized three-channel image tensor and the
            GT tensor (GT is converted by ``ToTensor`` but not normalized).
        """
        image_path = self.image_paths[index]
        GT_path = self._mask_path(image_path)

        # Normalize below uses three-channel statistics, so force RGB; this
        # keeps grayscale, palette and RGBA inputs from failing there.
        # convert()/copy() load the pixel data, so the files can be closed.
        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')
        with Image.open(GT_path) as GT_file:
            GT = GT_file.copy()

        # height / width of the input; the same geometric transforms are
        # applied to the GT (assumes GT has the same size - TODO confirm).
        aspect_ratio = image.size[1] / image.size[0]

        Transform = []

        # Random rescale to a width of 300-320 px, keeping the aspect ratio.
        ResizeRange = random.randint(300, 320)
        Transform.append(T.Resize((max(1, int(ResizeRange * aspect_ratio)), ResizeRange)))
        p_transform = random.random()

        if self.mode == 'train' and p_transform <= self.augmentation_prob:
            RotationDegree = self.RotationDegree[random.randint(0, 3)]
            if RotationDegree in (90, 270):
                # A quarter turn swaps width and height.
                aspect_ratio = 1 / aspect_ratio

            # (d, d) ranges make RandomRotation deterministic, so image and
            # GT receive exactly the same rotation.
            Transform.append(T.RandomRotation((RotationDegree, RotationDegree)))

            RotationRange = random.randint(-10, 10)
            Transform.append(T.RandomRotation((RotationRange, RotationRange)))
            CropRange = random.randint(250, 270)
            Transform.append(T.CenterCrop((max(1, int(CropRange * aspect_ratio)), CropRange)))
            Transform = T.Compose(Transform)

            image = Transform(image)
            GT = Transform(GT)

            # Random shift: trim up to 20 px from each side, using the same
            # box for image and GT.
            ShiftRange_left = random.randint(0, 20)
            ShiftRange_upper = random.randint(0, 20)
            ShiftRange_right = image.size[0] - random.randint(0, 20)
            ShiftRange_lower = image.size[1] - random.randint(0, 20)
            box = (ShiftRange_left, ShiftRange_upper, ShiftRange_right, ShiftRange_lower)
            image = image.crop(box=box)
            GT = GT.crop(box=box)

            if random.random() < 0.5:
                image = F.hflip(image)
                GT = F.hflip(GT)

            if random.random() < 0.5:
                image = F.vflip(image)
                GT = F.vflip(GT)

            # Photometric jitter is applied to the image only, never the GT.
            image = T.ColorJitter(brightness=0.2, contrast=0.2, hue=0.02)(image)

            # The geometric pipeline has been applied already; start a new
            # list for the final resize below.
            Transform = []

        # Final size: width 256, height rounded down to a multiple of 16.
        Transform.append(T.Resize((self._output_height(aspect_ratio), 256)))
        Transform.append(T.ToTensor())
        Transform = T.Compose(Transform)

        image = Transform(image)
        GT = Transform(GT)

        Norm_ = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        image = Norm_(image)

        return image, GT

    def __len__(self):
        """Returns the total number of images in the dataset."""
        return len(self.image_paths)
def get_loader(image_path, image_size, batch_size, num_workers=2, mode='train', augmentation_prob=0.4):
    """Create an ``ImageFolder`` dataset and wrap it in a shuffling DataLoader.

    Args:
        image_path: Dataset root directory, passed to ``ImageFolder`` as ``root``.
        image_size: Forwarded to ``ImageFolder``.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        mode: Forwarded to ``ImageFolder``.
        augmentation_prob: Forwarded to ``ImageFolder``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset. Shuffling is
        enabled regardless of ``mode``.
    """
    folder_options = dict(
        root=image_path,
        image_size=image_size,
        mode=mode,
        augmentation_prob=augmentation_prob,
    )
    return data.DataLoader(
        ImageFolder(**folder_options),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )