-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
ImagePipelineExample.java
129 lines (110 loc) · 6.4 KB
/
ImagePipelineExample.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
/* *****************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.datapipelineexamples.formats.image;
import org.deeplearning4j.datapipelineexamples.utils.DownloaderUtility;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.MultiImageTransform;
import org.datavec.image.transform.ShowImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.io.File;
import java.util.Random;
/**
* Load images and augment them to build an image pipeline
* @author susaneraly
*/
public class ImagePipelineExample {

    //Images are of format given by allowedExtensions (the formats DataVec's image loader supports)
    private static final String [] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS;

    //Fixed seed so the train/test split and any random augmentations are reproducible across runs
    private static final long seed = 12345;
    private static final Random randNumGen = new Random(seed);

    //Every image is resized to height x width with 'channels' color channels before training
    private static final int height = 50;
    private static final int width = 50;
    private static final int channels = 3;

    public static String dataLocalPath;

    /**
     * Downloads the sample image dataset, builds a labelled train split from the directory
     * structure, wires an image-display transform into an {@link ImageRecordReader}, and
     * iterates the resulting {@link DataSetIterator}, printing each minibatch.
     *
     * @param args unused
     * @throws Exception if the download, file split, or record reader initialization fails
     */
    public static void main(String[] args) throws Exception {
        dataLocalPath = DownloaderUtility.DATAEXAMPLES.Download();
        //DIRECTORY STRUCTURE:
        //Images in the dataset have to be organized in directories by class/label.
        //In this example there are ten images in three classes
        //Here is the directory structure
        //                                    parentDir
        //                                  /    |     \
        //                                 /     |      \
        //                            labelA  labelB   labelC
        //
        //Set your data up like this so that labels from each label/class live in their own directory
        //And these label/class directories live together in the parent directory
        File parentDir = new File(dataLocalPath, "ImagePipeline/");

        //Files in directories under the parent dir that have "allowed extensions".
        //The split needs a random number generator for reproducibility when splitting the files into train and test
        FileSplit filesInDir = new FileSplit(parentDir, allowedExtensions, randNumGen);

        //You do not have to manually specify labels. This class (instantiated as below) will
        //parse the parent dir and use the name of the subdirectories as label/class names
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();

        //The balanced path filter gives you fine tune control of the min/max cases to load for each class
        //Below is a bare bones version. Refer to javadoc for details
        BalancedPathFilter pathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker);

        //Split the image files into train and test. Specify the train test split as 80%,20%
        InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, 80, 20);
        InputSplit trainData = filesInDirSplit[0];
        //InputSplit testData = filesInDirSplit[1]; //The testData is never used in the example, commenting out.

        //Specifying a new record reader with the height and width you want the images to be resized to.
        //Note that the images in this example are all of different size
        //They will all be resized to the height and width specified below
        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
        try {
            //Often there is a need to transform images to artificially increase the size of the dataset
            //DataVec has built in powerful features from OpenCV
            //You can chain transformations as shown below, or write your own classes that will say detect a face and crop to size
            /*ImageTransform transform = new MultiImageTransform(randNumGen,
                new CropImageTransform(10), new FlipImageTransform(),
                new ScaleImageTransform(10), new WarpImageTransform(10));
             */

            //You can use the ShowImageTransform to view your images
            //Code below gives you a look before and after, for a side by side comparison
            ImageTransform transform = new MultiImageTransform(randNumGen, new ShowImageTransform("Display - before "));

            //Initialize the record reader with the train data and the transform chain
            recordReader.initialize(trainData, transform);
            int outputNum = recordReader.numLabels();

            //convert the record reader to an iterator for training - Refer to other examples for how to use an iterator
            int batchSize = 10; // Minibatch size. Here: The number of images to fetch for each call to dataIter.next().
            int labelIndex = 1; // Index of the label Writable (usually an IntWritable), as obtained by recordReader.next()
            // List<Writable> lw = recordReader.next();
            // then lw.get(0) = NDArray shaped [1,3,50,50] (1, channels, height, width)
            // and  lw.get(1) = label as an integer (matches labelIndex above).
            DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, outputNum);

            while (dataIter.hasNext()) {
                DataSet ds = dataIter.next();
                System.out.println(ds);
                try {
                    Thread.sleep(3000); //Pause 3 seconds so each displayed minibatch can be inspected
                } catch (InterruptedException ex) {
                    //Restore the interrupt flag and stop iterating instead of spinning through the rest of the data
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        } finally {
            //RecordReader is Closeable; release the underlying (native image) resources
            recordReader.close();
        }
    }
}