-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
SVMLightExample.java
126 lines (105 loc) · 5.7 KB
/
SVMLightExample.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
/* *****************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.datapipelineexamples.formats.svmlight;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.reader.impl.misc.SVMLightRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datapipelineexamples.utils.DownloaderUtility;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
/**
 * Example: train a small feed-forward classifier on MNIST data stored in SVMLight format.
 *
 * <p>The training and test files are read with {@link SVMLightRecordReader}, wrapped in
 * {@link RecordReaderDataSetIterator}s and fed to a network with a single dense hidden layer.
 * The model is evaluated on the test file after every second epoch.
 */
public class SVMLightExample {

    private static final Logger log = LoggerFactory.getLogger(SVMLightExample.class);

    /** Local directory holding the downloaded example data; assigned when {@link #main} runs. */
    public static String dataLocalPath;

    /**
     * Downloads the example data, trains the network for a fixed number of epochs and logs
     * evaluation statistics on the test data after every second epoch.
     *
     * @param args command-line arguments (not used)
     * @throws Exception if downloading the data, reading the files or training fails
     */
    public static void main(String[] args) throws Exception {
        int numOfFeatures = 784;      // For the MNIST data set, each row is a flattened 28x28 pixel picture of a handwritten digit = 784 values
        int numOfClasses = 10;        // 10 classes (the digits) in the data set. Zero indexing. Classes have integer values 0, 1 or 2 ... 9
        int batchSize = 10;           // 1000 training examples with a batch size of 10: around 100 iterations per epoch
        int printIterationsNum = 20;  // print score every 20 iterations
        int hiddenLayer1Num = 200;
        long seed = 42;
        int nEpochs = 4;

        dataLocalPath = DownloaderUtility.DATAEXAMPLES.Download();

        // Both readers share the same settings: zero-based feature indexing and a fixed feature count.
        Configuration config = new Configuration();
        config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, true);
        config.setInt(SVMLightRecordReader.NUM_FEATURES, numOfFeatures);

        // try-with-resources so the file-backed readers are closed even if training fails.
        try (SVMLightRecordReader trainRecordReader = new SVMLightRecordReader();
             SVMLightRecordReader testRecordReader = new SVMLightRecordReader()) {

            trainRecordReader.initialize(config, new FileSplit(new File(dataLocalPath,"MnistSVMLightExample/mnist_svmlight_train_1000.txt")));
            // The third argument is the label index: the label follows the numOfFeatures feature columns.
            DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, batchSize, numOfFeatures, numOfClasses);

            testRecordReader.initialize(config, new FileSplit(new File(dataLocalPath,"MnistSVMLightExample/mnist_svmlight_test_100.txt")));
            DataSetIterator testIter = new RecordReaderDataSetIterator(testRecordReader, batchSize, numOfFeatures, numOfClasses);

            log.info("Build model....");
            MultiLayerConfiguration conf = buildConfiguration(seed, numOfFeatures, hiddenLayer1Num, numOfClasses);

            // run the model
            MultiLayerNetwork model = new MultiLayerNetwork(conf);
            model.init();
            model.setListeners(new ScoreIterationListener(printIterationsNum));

            for (int n = 0; n < nEpochs; n++) {
                model.fit(trainIter);
                log.info("Epoch {} finished training", n + 1);

                // evaluate the model on test data, once every second epoch
                if ((n + 1) % 2 == 0) {
                    Evaluation eval = evaluate(model, testIter, numOfClasses);
                    // String.format is kept here: SLF4J placeholders cannot apply the %.3f formatting.
                    log.info(String.format("Evaluation on test data - [Epoch %d] [Accuracy: %.3f, P: %.3f, R: %.3f, F1: %.3f] ",
                            n + 1, eval.accuracy(), eval.precision(), eval.recall(), eval.f1()));
                    log.info(eval.stats());
                }
            }
        }

        System.out.println("Finished...");
    }

    /** Builds the network configuration: one ReLU dense layer followed by a softmax output layer. */
    private static MultiLayerConfiguration buildConfiguration(long seed, int numInputs, int numHidden, int numOutputs) {
        return new NeuralNetConfiguration.Builder()
                .seed(seed)
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .updater(Adam.builder().learningRate(0.02).beta1(0.9).beta2(0.999).build())
                .l2(1e-4)
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHidden)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(numHidden).nOut(numOutputs).build())
                .build();
    }

    /** Resets the test iterator and accumulates the model's predictions over every batch it yields. */
    private static Evaluation evaluate(MultiLayerNetwork model, DataSetIterator testIter, int numOfClasses) {
        Evaluation eval = new Evaluation(numOfClasses);
        testIter.reset();
        while (testIter.hasNext()) {
            DataSet batch = testIter.next();
            INDArray predicted = model.output(batch.getFeatures(), false);
            eval.eval(batch.getLabels(), predicted);
        }
        return eval;
    }
}