At text classification example #945

Open. Wants to merge 89 commits into base: master.

Commits (89):
de5f44c
examples added + changed nd4j backend in pom.xml to run on DGX1
atuzhykov Feb 20, 2020
d65fa9a
examples added + changed nd4j backend in pom.xml to run on DGX1
atuzhykov Feb 21, 2020
c2967df
other small changes
atuzhykov Feb 21, 2020
9e87835
small fix to match cuda version with container
atuzhykov Feb 21, 2020
41530e1
small fix to match cuda version with container
atuzhykov Feb 21, 2020
3198021
lr 1e-3 > 4e-3 (as multiplying batchsize*k, lr*sqrt(k))
atuzhykov Feb 22, 2020
9a1ba54
lr 1e-3 > 4e-3 (as multiplying batchsize*k, lr*sqrt(k))
atuzhykov Feb 23, 2020
95aa639
experiment0
atuzhykov Feb 24, 2020
7501626
experiment1 (notes belong to commit name are here http://tiny.cc/yashkz)
atuzhykov Feb 25, 2020
091c386
experiment2 (notes belong to commit name are here http://tiny.cc/yashkz)
atuzhykov Feb 25, 2020
86a6518
experiment3 (notes belong to commit name are here http://tiny.cc/yashkz)
atuzhykov Feb 25, 2020
d669d6f
experiment4 (notes belong to commit name are here http://tiny.cc/yashkz)
atuzhykov Feb 25, 2020
578a186
experiment5 (notes belong to commit name are here http://tiny.cc/yashkz)
atuzhykov Feb 25, 2020
fe42967
experiment6 (notes belong to commit name are here http://tiny.cc/yashkz)
atuzhykov Feb 25, 2020
302f7bb
experiment7 (notes belong to commit name are here http://tiny.cc/yashkz)
atuzhykov Feb 25, 2020
bb933c5
experiment8 (notes belong to commit name are here http://tiny.cc/yashkz)
atuzhykov Feb 25, 2020
000f7d8
experiment9 (notes belong to commit name are here http://tiny.cc/yashkz)
atuzhykov Feb 25, 2020
79014f8
experiment9 (notes belong to commit name are here http://tiny.cc/yashkz)
atuzhykov Feb 25, 2020
4a5cffd
experiment10 (notes belong to commit name are here http://tiny.cc/yas…
atuzhykov Feb 25, 2020
8ef7519
experiment10 (notes belong to commit name are here http://tiny.cc/yas…
atuzhykov Feb 25, 2020
b99d59a
experiment11 (notes belong to commit name are here http://tiny.cc/yas…
atuzhykov Feb 25, 2020
1aabba1
experiment12 (notes belong to commit name are here http://tiny.cc/yas…
atuzhykov Feb 26, 2020
dc3f3b3
experiment12 (notes belong to commit name are here http://tiny.cc/yas…
atuzhykov Feb 26, 2020
f16d8ac
experiment12 (notes belong to commit name are here http://tiny.cc/yas…
atuzhykov Feb 26, 2020
36ae8ee
experiment12 (notes belong to commit name are here http://tiny.cc/yas…
atuzhykov Feb 26, 2020
0de07d8
experiment13 (notes belong to commit name are here http://tiny.cc/yas…
atuzhykov Feb 26, 2020
f2eece6
experiment14 (notes belong to commit name are here http://tiny.cc/yas…
atuzhykov Feb 26, 2020
ff91a96
baseline conf + LengthHandling.FIXED_LENGTH=256
atuzhykov Feb 27, 2020
2e99c55
baselineconf+LengthHandling.FIXED_LENGTH=256+Bidirectional_lstm
atuzhykov Feb 28, 2020
e107c2c
baselineconf+LengthHandling.FIXED_LENGTH=256+Bidirectional_lstm+lr1e-4
atuzhykov Feb 28, 2020
c09746d
baselineconf+LengthHandling.FIXED_LENGTH=256+Bidirectional_lstm_256
atuzhykov Feb 28, 2020
2b6414a
baselineconf+LengthHandling.FIXED_LENGTH=256+Bidirectional_lstm_256_l…
atuzhykov Feb 28, 2020
834c338
base_conf+bidir_LSTM_256_layersize_Adam_lr1e-3_SGD_lr1e-3_for_EmbdLayer
Mar 2, 2020
d6b2644
base_conf+bidir_LSTM_256_layersize_Adam_lr1e-3_SGD_lr1e-3_for_EmbdLayer
atuzhykov Mar 2, 2020
615657f
base_conf+bidir_LSTM_256_layersize_Nadam_lr1e-3
atuzhykov Mar 2, 2020
77d6c27
base_conf+bidir_LSTM_256_layersize_Nadam_lr1e-3
atuzhykov Mar 2, 2020
d4ef045
base_conf+bidir_LSTM_256_layersize_Nadam_lr1e-3
atuzhykov Mar 2, 2020
8a74305
base_conf+3x_bidir_LSTM_256_layersize_Nadam_lr1e-3
atuzhykov Mar 2, 2020
1668928
base_conf+3xbidir_LSTM_256_layersize_Adam_lr1e-3_l21e-5
atuzhykov Mar 2, 2020
022c4dc
base_conf+3x_bidir_LSTM_256_layersize_Adam_Sheduled_lr
atuzhykov Mar 2, 2020
c7f5393
base_conf+2x_bidir_LSTM_256_Adam_lr1e-3_lstm_dropout_075
atuzhykov Mar 3, 2020
ccecf71
prefinal examples
atuzhykov Mar 4, 2020
74162ff
prefinal
atuzhykov Mar 4, 2020
13a2392
changed package and class name, added trained model URL
atuzhykov Mar 5, 2020
5b1b710
fixed required changes
atuzhykov Mar 6, 2020
a22956d
fixed new round of required changes
atuzhykov Mar 11, 2020
5e7df4f
small issue belong to match BertIterator and DataSetIterator in Evalu…
atuzhykov Mar 11, 2020
@@ -0,0 +1,280 @@
/*******************************************************************************
* Copyright (c) Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.examples.nlp.sentencepiecernnexample;


import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.iterator.BertIterator;
import org.deeplearning4j.iterator.provider.FileLabeledSentenceProvider;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.deeplearning4j.examples.utilities.DataUtilities;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.*;


/**
* Example: Given a movie review (raw text), classify that movie review as either positive or negative based on the words it contains.
* This is done by combining BertIterator with a recurrent neural network model: an embedding sequence layer with nIn = vocabulary size,
* two bidirectional LSTM layers, followed by a global pooling layer and an output layer with nOut = 2 (2 classes: positive and negative reviews).
* Since the model is prone to overfitting, we also add L2 regularization and dropout for certain layers.
* To prepare the reviews we use BertIterator, a MultiDataSetIterator for training BERT (Transformer) models.
* We configure BertIterator for supervised sequence classification:
* 0. As the tokenizer we use BertWordPieceTokenizerFactory with the provided BERT BASE UNCASED vocabulary.
Reviewer comment (Contributor):
Maybe let's improve this slightly, add another line under 0.:
BertIterator and BertWordPieceTokenizer implement the Word Piece sub-word tokenization algorithm, with a vocabulary size of 30522 tokens.
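The WordPiece scheme the reviewer mentions can be illustrated with a plain-Java sketch: a simplified greedy longest-match-first tokenizer over a toy vocabulary. The class and vocabulary below are hypothetical illustrations; the real BertWordPieceTokenizerFactory additionally handles casing, punctuation splitting, and the full 30522-token vocab file.

```java
import java.util.*;

public class WordPieceSketch {
    // Greedy longest-match-first WordPiece tokenization over a toy vocabulary.
    // Continuation pieces carry the "##" prefix, as in the BERT vocab files.
    static List<String> tokenize(String word, Set<String> vocab) {
        List<String> pieces = new ArrayList<>();
        int start = 0;
        while (start < word.length()) {
            String match = null;
            int end;
            // Try the longest substring first, shrinking until a vocab entry matches
            for (end = word.length(); end > start; end--) {
                String piece = word.substring(start, end);
                if (start > 0) piece = "##" + piece;  // mark continuation of a word
                if (vocab.contains(piece)) { match = piece; break; }
            }
            if (match == null) return Collections.singletonList("[UNK]");
            pieces.add(match);
            start = end;
        }
        return pieces;
    }

    public static void main(String[] args) {
        Set<String> vocab = new HashSet<>(Arrays.asList("un", "##aff", "##able"));
        System.out.println(tokenize("unaffable", vocab));  // [un, ##aff, ##able]
    }
}
```

Rare words thus decompose into known sub-words instead of mapping to a single out-of-vocabulary token, which keeps the vocabulary small.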

* 1. We fix the sequence length: longer sequences are trimmed and shorter ones padded to 256 tokens.
* 2. The sentence provider takes a reviewFilesMap, constructed from the dataset described below.
* 3. FeatureArrays configures which arrays are included: <b>INDICES_MASK</b> means the
* indices array and mask array only, with no segment ID array; this returns 1 feature array and 1 feature mask array (plus labels).
* 4. As the task we specify BertIterator.Task.SEQ_CLASSIFICATION, i.e. sequence classification.
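The fixed-length handling in step 1 amounts to simple trim-or-pad logic over token ids. A minimal sketch (the helper below is a hypothetical illustration, not the BertIterator internals):

```java
import java.util.Arrays;

public class FixedLengthSketch {
    // Trim token-id sequences longer than fixedLen; right-pad shorter ones with padId.
    static int[] toFixedLength(int[] tokenIds, int fixedLen, int padId) {
        int[] out = new int[fixedLen];
        Arrays.fill(out, padId);
        System.arraycopy(tokenIds, 0, out, 0, Math.min(tokenIds.length, fixedLen));
        return out;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(toFixedLength(new int[]{7, 8, 9}, 5, 0)));  // [7, 8, 9, 0, 0]
        System.out.println(Arrays.toString(toFixedLength(new int[]{7, 8, 9}, 2, 0)));  // [7, 8]
    }
}
```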
* Training data is the "Large Movie Review Dataset" from http://ai.stanford.edu/~amaas/data/sentiment/
* This data set contains 25,000 training reviews + 25,000 testing reviews
* <p>
* Process:
* 0. Automatic on first run of the example: download the data (movie reviews), extract it, and download the BERT BASE UNCASED vocabulary file.
* 1. Initialize BertWordPieceTokenizerFactory with the provided vocab.
* 2. Configure the MultiLayerNetwork.
* 3. Set up BertIterator, get the train and test data, and apply the preprocessor.
* 4. Train the network.
* <p>
* With the current configuration, this gives approx. 86% accuracy after 19 epochs. Better performance may be possible with
* additional tuning.
* <p>
* NOTE: You may download the already trained model defined below for your own inference:
* https://dl4jdata.blob.core.windows.net/dl4j-examples/models/sentencepiece_rnn_example_model.zip
* <p>
* Recommended papers:
* 0. SentencePiece: A simple and language independent subword tokenizer and detokenizer for Neural Text Processing
* https://arxiv.org/abs/1808.06226
* 1. Attention Is All You Need
* https://arxiv.org/abs/1706.03762
* @author Andrii Tuzhykov
*/
public class SentencePieceRNNExample {


/**
* Data URL for downloading
*/
public static final String DATA_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz";
/**
* Bert Base Uncased Vocabulary URL
*/
public static final String VOCAB_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt";

/**
* Location to save and extract the training/testing data
*/
public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_w2vSentiment/");


public static void main(String[] args) throws Exception {

//Download and extract data
downloadData();


final int seed = 0; // seed for reproducibility
String pathToVocab = DATA_PATH + "vocab.txt"; // path to vocab

// BertWordPieceTokenizerFactory initialized with given vocab
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(new File(pathToVocab), true, true, StandardCharsets.UTF_8);

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.updater(new Adam(1e-3))
.l2(1e-5)
.weightInit(WeightInit.XAVIER)
.list()
// matching EmbeddingSequenceLayer outputs with Bidirectional LSTM inputs
.setInputType(InputType.recurrent(1))
// initialize weights with a normal distribution, set nIn to the vocab size, and disable L2 for this layer
.layer(0, new EmbeddingSequenceLayer.Builder().weightInit(new NormalDistribution(0, 1)).l2(0)
.hasBias(true).nIn(t.getVocab().size()).nOut(128).build())
// two Bidirectional LSTM layers in a row with dropout and tanh as activation function
.layer(new Bidirectional(new LSTM.Builder().nOut(256)
.dropOut(0.8).activation(Activation.TANH).build()))
.layer(new Bidirectional(new LSTM.Builder().nOut(256)
.dropOut(0.8).activation(Activation.TANH).build()))
.layer(new GlobalPoolingLayer(PoolingType.MAX))
// define the last layer with 2 outputs (2 classes: positive and negative),
// a small dropout to reduce overfitting, and the MCXENT loss function
.layer(new OutputLayer.Builder().nOut(2)
.dropOut(0.97).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
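The GlobalPoolingLayer(PoolingType.MAX) above collapses the LSTM's per-time-step output into one fixed-size vector by taking, for each feature, the maximum over all time steps. Conceptually (a plain-Java sketch, not the DL4J implementation):

```java
import java.util.Arrays;

public class MaxPoolSketch {
    // Collapse a [timeSteps][features] activation sequence to one vector:
    // for each feature, keep the maximum value seen over all time steps.
    static double[] maxOverTime(double[][] seq) {
        double[] out = seq[0].clone();
        for (double[] step : seq)
            for (int f = 0; f < out.length; f++)
                out[f] = Math.max(out[f], step[f]);
        return out;
    }

    public static void main(String[] args) {
        double[][] seq = {{1.0, 5.0}, {3.0, 2.0}, {2.0, 4.0}};
        System.out.println(Arrays.toString(maxOverTime(seq)));  // [3.0, 5.0]
    }
}
```

This is what lets a sequence model feed a standard classification output layer: whatever the sequence length, the pooled vector has a fixed size.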

// Get train and test BertIterators by toggling the isTraining argument:
// true returns the train iterator, false the test iterator
BertIterator train = getBertDataSetIterator(true, t);
BertIterator test = getBertDataSetIterator(false, t);

// Preprocessor for DataType matching; can be removed after 1.0.0-beta7 release.
MultiDataSetPreProcessor mdsPreprocessor = new MultiDataSetPreProcessor() {
@Override
public void preProcess(MultiDataSet multiDataSet) {
multiDataSet.setFeaturesMaskArray(0, multiDataSet.getFeaturesMaskArray(0).castTo(DataType.FLOAT));
}
};

// Apply the preprocessor to both the train and test datasets
train.setPreProcessor(mdsPreprocessor);
test.setPreProcessor(mdsPreprocessor);

// Initialize a MultiLayerNetwork instance with the configuration described above
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();


//Initialize the user interface backend
UIServer uiServer = UIServer.getInstance();

//Configure where the network information (gradients, activations, score vs. time etc) is to be stored
//Then add the StatsListener to collect this information from the network, as it trains
StatsStorage statsStorage = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "ui-stats-" + System.currentTimeMillis() + ".dl4j"));
int listenerFrequency = 20;
net.setListeners(new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(50));
//Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
uiServer.attach(statsStorage);
Reviewer comment (Contributor):
Maybe let's comment out the UI by default, as it adds some overhead (slows down training a bit). Users can uncomment it if they want to run it with UI. That would look like this:

        /*
        //Uncomment this section to run the example with the user interface
        UIServer uiServer = UIServer.getInstance();

        //Configure where the network information (gradients, activations, score vs. time etc) is to be stored
        //Then add the StatsListener to collect this information from the network, as it trains
        StatsStorage statsStorage = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "ui-stats-" + System.currentTimeMillis() + ".dl4j"));
        int listenerFrequency = 20;
        net.setListeners(new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(50));
        //Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
        uiServer.attach(statsStorage);
        */
        
        net.setListeners(new ScoreIterationListener(50));



// Train the network for 19 epochs (note: the network state persists across epochs, i.e. each iteration continues training)
for (int i = 1; i <= 19; i++) {

net.fit(train);

// Get and print accuracy, precision, recall & F1 and confusion matrix
Evaluation eval = net.doEvaluation(test, new Evaluation[]{new Evaluation()})[0];
Reviewer comment (Contributor):
For MultiLayerNetwork, we can use net.evaluate(test)

System.out.println("===== Evaluation at training iteration " + i + " =====");
System.out.println(eval.stats());
}

}

/**
* Get BertIterator instance.
*
* @param isTraining specifies which dataset iterator we want to get: train or test.
* @param t BertWordPieceTokenizerFactory initialized with provided vocab.
* @return BertIterator with specified parameters.
*/
public static BertIterator getBertDataSetIterator(boolean isTraining, BertWordPieceTokenizerFactory t) {

String path = FilenameUtils.concat(DATA_PATH, (isTraining ? "aclImdb/train/" : "aclImdb/test/"));
String positiveBaseDir = FilenameUtils.concat(path, "pos");
String negativeBaseDir = FilenameUtils.concat(path, "neg");
Random rng = new Random(42);

File filePositive = new File(positiveBaseDir);
File fileNegative = new File(negativeBaseDir);

Map<String, List<File>> reviewFilesMap = new HashMap<>();
reviewFilesMap.put("Positive", Arrays.asList(Objects.requireNonNull(filePositive.listFiles())));
reviewFilesMap.put("Negative", Arrays.asList(Objects.requireNonNull(fileNegative.listFiles())));


BertIterator b = BertIterator.builder()
.tokenizer(t)
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 256)
.minibatchSize(32)
.sentenceProvider(new FileLabeledSentenceProvider(reviewFilesMap, rng))
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
.vocabMap(t.getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.build();


return b;
}

public static void downloadData() throws Exception {
//Create directory if required
File directory = new File(DATA_PATH);
if (!directory.exists()) directory.mkdir();

//Download file:
String archivePath = DATA_PATH + "aclImdb_v1.tar.gz";
File archiveFile = new File(archivePath);
String extractedPath = DATA_PATH + "aclImdb";
File extractedFile = new File(extractedPath);

if (!archiveFile.exists()) {
System.out.println("Starting data download (80MB)...");
FileUtils.copyURLToFile(new URL(DATA_URL), archiveFile);
System.out.println("Data (.tar.gz file) downloaded to " + archiveFile.getAbsolutePath());
//Extract tar.gz file to output directory
DataUtilities.extractTarGz(archivePath, DATA_PATH);
} else {
//Assume if archive (.tar.gz) exists, then data has already been extracted
System.out.println("Data (.tar.gz file) already exists at " + archiveFile.getAbsolutePath());
if (!extractedFile.exists()) {
//Extract tar.gz file to output directory
DataUtilities.extractTarGz(archivePath, DATA_PATH);
} else {
System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath());
}
}


// Download Bert Base Uncased Vocab
String vocabPath = DATA_PATH + "vocab.txt";
File vocabFile = new File(vocabPath);

if (!vocabFile.exists()) {
try (BufferedInputStream inputStream = new BufferedInputStream(new URL(VOCAB_URL).openStream());
FileOutputStream file = new FileOutputStream(DATA_PATH + "vocab.txt")) {
byte[] data = new byte[1024];
int byteContent;
while ((byteContent = inputStream.read(data, 0, 1024)) != -1) {
file.write(data, 0, byteContent);
}
} catch (IOException e) {
System.out.println("Something went wrong getting Bert Base Vocabulary");
}

} else {
System.out.println("Vocab file already exists at " + vocabFile.getAbsolutePath());
}

}
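The manual 1 KB buffer loop in downloadData() could also be written with java.nio.file.Files.copy, which streams an InputStream straight to a file. A self-contained sketch (the class and method names here are hypothetical, and the demo uses a local file:// URL so it runs offline):

```java
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;

public class DownloadSketch {
    // Stream a URL's contents straight to a file, replacing any previous copy.
    static void download(URL url, Path target) throws Exception {
        try (InputStream in = url.openStream()) {
            Files.copy(in, target, StandardCopyOption.REPLACE_EXISTING);
        }
    }

    public static void main(String[] args) throws Exception {
        // Demonstrate with a local file:// URL instead of a network download.
        Path src = Files.createTempFile("vocab-src", ".txt");
        Files.write(src, "hello vocab".getBytes());
        Path dst = Files.createTempFile("vocab-dst", ".txt");
        download(src.toUri().toURL(), dst);
        System.out.println(new String(Files.readAllBytes(dst)));  // hello vocab
    }
}
```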
}

2 changes: 1 addition & 1 deletion pom.xml
@@ -28,7 +28,7 @@
<properties>
<!-- Change the nd4j.backend property to nd4j-cuda-9.2-platform,nd4j-cuda-10.0-platform or nd4j-cuda-10.1-platform to use CUDA GPUs -->
<nd4j.backend>nd4j-native-platform</nd4j.backend>
<!-- <nd4j.backend>nd4j-cuda-10.2-platform</nd4j.backend>-->
<!-- <nd4j.backend>nd4j-cuda-10.0-platform</nd4j.backend>-->
Reviewer comment (Contributor):
Leave this commented out with 10.2

<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<shadedClassifier>bin</shadedClassifier>
