-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
Ex05_SamplingBaseInputSplitExample.java
101 lines (86 loc) · 4.13 KB
/
Ex05_SamplingBaseInputSplitExample.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
/*******************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.datapipelineexamples.loading;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.PathFilter;
import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.deeplearning4j.datapipelineexamples.utils.DownloaderUtility;
import java.io.File;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Random;
/**
 * {@link org.datavec.api.split.BaseInputSplit} and its implementations provide a
 * {@link org.datavec.api.split.BaseInputSplit#sample(PathFilter, double...)} method that is very useful for generating
 * several {@link org.datavec.api.split.InputSplit}s from the main split.
 * <p>
 * This can be used to divide a dataset into several subsets, for example into training, validation and test sets.
 * <p>
 * The {@link PathFilter} argument filters (and may reorder) the paths of the main split before the resulting
 * input splits are generated. The remaining arguments are relative weights, one per output split; they do not
 * need to sum to 100 or to 1.
 * <p>
 * The samples are divided in the following way: totalSamples * (weight1/totalWeightSum, weight2/totalWeightSum, ...,
 * weightN/totalWeightSum)
 * <p>
 * {@link PathFilter} has two default implementations:
 * <ul>
 *   <li>{@link org.datavec.api.io.filters.RandomPathFilter}, which simply randomizes the order of the paths;</li>
 *   <li>{@link org.datavec.api.io.filters.BalancedPathFilter}, which randomizes the order of the paths and randomly
 *   removes paths so that every label has the same number of paths. It also interlaces the output paths by label,
 *   which makes it easy to build balanced batches for training.</li>
 * </ul>
 * Both usages are shown here.
 */
public class Ex05_SamplingBaseInputSplitExample {

    /** Local directory of the downloaded example data; set at the start of {@link #main(String[])}. */
    public static String dataLocalPath;

    public static void main(String[] args) throws Exception {
        dataLocalPath = DownloaderUtility.INPUTSPLIT.Download();
        FileSplit fileSplit = new FileSplit(new File(dataLocalPath, "files"));

        // Sampling with a RandomPathFilter: five equally weighted splits of the shuffled paths.
        InputSplit[] inputSplits1 = fileSplit.sample(
                new RandomPathFilter(new Random(123), null),
                10, 10, 10, 10, 10);
        System.out.println(String.format("Random filtered splits -> Total(%d) = Splits of (%s)", fileSplit.length(),
                joinLengths(inputSplits1)));

        // Sampling with a BalancedPathFilter: labels come from each file's parent directory name.
        InputSplit[] inputSplits2 = fileSplit.sample(
                new BalancedPathFilter(new Random(123), null, new ParentPathLabelGenerator()),
                10, 10, 10, 10, 10);
        System.out.println(String.format("Balanced Splits are: %s",
                joinLengths(inputSplits2)));
    }

    /**
     * Renders the length of each split, joined by {@code " + "}.
     *
     * @param splits the splits to describe
     * @return e.g. {@code "3 + 3 + 2"}
     */
    private static String joinLengths(InputSplit[] splits) {
        // String.join accepts any Iterable; a lambda returning a fresh iterator satisfies that interface.
        return String.join(" + ", () -> new InputSplitLengthIterator(splits));
    }

    /** Iterates over the string form of each {@link InputSplit}'s {@code length()}. */
    private static final class InputSplitLengthIterator implements Iterator<CharSequence> {
        private final InputSplit[] inputSplits;
        private int i;

        InputSplitLengthIterator(InputSplit[] inputSplits) {
            this.inputSplits = inputSplits;
            this.i = 0;
        }

        @Override
        public boolean hasNext() {
            return i < inputSplits.length;
        }

        @Override
        public CharSequence next() {
            // Iterator contract: signal exhaustion with NoSuchElementException, not an array index error.
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            return String.valueOf(inputSplits[i++].length());
        }
    }
}