-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
FitFromFeaturized.java
127 lines (115 loc) · 6.43 KB
/
FitFromFeaturized.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
/* *****************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.examples.advanced.features.transferlearning.editlastlayer.presave;
import org.deeplearning4j.examples.advanced.features.transferlearning.iterators.FlowerDataSetIteratorFeaturized;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.deeplearning4j.zoo.ZooModel;
import org.deeplearning4j.zoo.model.VGG16;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import java.io.IOException;
/**
 * Fits the unfrozen part of a VGG16-based transfer learning model from pre-saved "featurized" datasets.
 *
 * <p>Important:
 * Run the class {@link FeaturizedPreSave} before attempting to run this. The outputs at the boundary of the frozen
 * and unfrozen vertices of a model are saved; these are referred to as "featurized" datasets in this description.
 * On a dataset of about 3000 images, which is what is downloaded, this can take "a while".
 *
 * <p>Here we see how the transfer learning helper can be used to fit from a featurized dataset.
 * We attempt to train the same model architecture as the one in "EditLastLayerOthersFrozen".
 * Since the helper avoids the forward pass through the frozen layers we save on computation time when running
 * multiple epochs. In this manner, users can iterate quickly tweaking learning rates, weight initialization etc.
 * to settle on a model that gives good results.
 *
 * @author susaneraly on 3/10/17.
 */
@SuppressWarnings("DuplicatedCode")
public class FitFromFeaturized {
    private static final Logger log = org.slf4j.LoggerFactory.getLogger(FitFromFeaturized.class);

    /** Name of the layer at (and below) which the graph is frozen; shared with {@link FeaturizedPreSave}. */
    private static final String featureExtractionLayer = FeaturizedPreSave.featurizeExtractionLayer;

    /** Seed passed to the fine tune configuration. */
    protected static final long seed = 12345;

    /** Number of outputs of the replacement "predictions" layer. */
    protected static final int numClasses = 5;

    /** Number of passes over the featurized training iterator. */
    protected static final int nEpochs = 3;

    /** Number of inputs of the replacement "predictions" layer, which is attached to "fc2". */
    private static final int predictionsNIn = 4096;

    /** The test set is evaluated every this many training iterations (including iteration 0 of each epoch). */
    private static final int evalFrequency = 10;

    /**
     * Loads the pretrained VGG16 graph, replaces its "predictions" vertex with a new output layer and trains the
     * unfrozen part from the featurized datasets, logging evaluation statistics along the way.
     *
     * @param args command line arguments; not used
     * @throws IOException if an I/O error is raised by the calls made here (e.g. while loading the pretrained model)
     */
    public static void main(String[] args) throws IOException {
        //Import vgg
        //Note that the model imported does not have an output layer (check printed summary)
        // nor any training related configs (model from keras was imported with only weights and json)
        log.info("\n\nLoading org.deeplearning4j.transferlearning.vgg16...\n\n");
        ZooModel zooModel = VGG16.builder().build();
        ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained();
        log.info(vgg16.summary());

        //Construct a new model with the intended architecture and print summary
        ComputationGraph vgg16Transfer = buildTransferModel(vgg16);
        log.info(vgg16Transfer.summary());

        DataSetIterator trainIter = FlowerDataSetIteratorFeaturized.trainIterator();
        DataSetIterator testIter = FlowerDataSetIteratorFeaturized.testIterator();

        //Instantiate the transfer learning helper to fit and output from the featurized dataset
        //The .unfrozenGraph() is the unfrozen subset of the computation graph passed in.
        //If using with a UI or a listener attach them directly to the unfrozenGraph instance
        //With each iteration updated params from unfrozenGraph are copied over to the original model
        TransferLearningHelper transferLearningHelper = new TransferLearningHelper(vgg16Transfer);
        log.info(transferLearningHelper.unfrozenGraph().summary());

        //Baseline: evaluate once before any fitting has happened
        String statsBeforeFit = evaluateOnTestSet(transferLearningHelper, testIter);
        log.info("Eval stats BEFORE fit.....");
        log.info("{}\n", statsBeforeFit);

        for (int epoch = 0; epoch < nEpochs; epoch++) {
            int iter = 0;
            while (trainIter.hasNext()) {
                transferLearningHelper.fitFeaturized(trainIter.next());
                if (iter % evalFrequency == 0) {
                    log.info("Evaluate model at iter {} ....", iter);
                    log.info(evaluateOnTestSet(transferLearningHelper, testIter));
                }
                iter++;
            }
            //Rewind the training data for the next epoch
            trainIter.reset();
            log.info("Epoch #{} complete", epoch);
        }
        log.info("Model build complete");
    }

    /**
     * Builds the transfer learning graph: everything up to and including {@link #featureExtractionLayer} is set as
     * the feature extractor, and the "predictions" vertex is replaced by a new softmax output layer with
     * {@link #numClasses} outputs connected to "fc2".
     *
     * @param vgg16 the pretrained graph to derive the new model from
     * @return the new graph with the replaced output layer
     */
    private static ComputationGraph buildTransferModel(ComputationGraph vgg16) {
        //Decide on a fine tune configuration to use.
        //In cases where there already exists a setting the fine tune setting will
        // override the setting for all layers that are not "frozen".
        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
            .updater(new Nesterovs(3e-5, 0.9))
            .seed(seed)
            .build();

        //This weight init dist gave better results than Xavier
        NormalDistribution weightInitDist =
            new NormalDistribution(0, 0.2 * (2.0 / (predictionsNIn + numClasses)));

        OutputLayer predictions = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
            .nIn(predictionsNIn).nOut(numClasses)
            .weightInit(weightInitDist)
            .activation(Activation.SOFTMAX).build();

        return new TransferLearning.GraphBuilder(vgg16)
            .fineTuneConfiguration(fineTuneConf)
            .setFeatureExtractor(featureExtractionLayer) //the specified layer and below are "frozen"
            .removeVertexKeepConnections("predictions") //replace the functionality of the final vertex
            .addLayer("predictions", predictions, "fc2")
            .build();
    }

    /**
     * Evaluates the helper's unfrozen graph on the given test iterator and resets the iterator afterwards so it
     * can be used for the next evaluation.
     *
     * @param transferLearningHelper helper whose unfrozen graph is evaluated
     * @param testIter iterator over the featurized test data
     * @return the evaluation statistics as text
     */
    private static String evaluateOnTestSet(TransferLearningHelper transferLearningHelper, DataSetIterator testIter) {
        Evaluation eval = transferLearningHelper.unfrozenGraph().evaluate(testIter);
        testIter.reset();
        return eval.stats();
    }
}