add end-to-end word2vec example #966
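
This PR adds two new files under org.deeplearning4j.examples.nlp.word2vec: ImdbSentimentIterator, a DataSetIterator that converts the raw IMDB reviews into padded sequences of word2vec vocabulary indices, and Word2VecInEmbeddingLayer, which trains a word2vec model on the reviews and reuses the resulting vectors as a frozen embedding layer in an LSTM sentiment classifier.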

ImdbSentimentIterator.java
@@ -0,0 +1,186 @@
package org.deeplearning4j.examples.nlp.word2vec;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;

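/**
 * DataSetIterator over the extracted IMDB review dataset (aclImdb). Each review is tokenized
 * and mapped to a sequence of word2vec vocabulary indices, front-padded to a fixed length,
 * and labelled positive or negative according to the directory it was read from.
 */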
public class ImdbSentimentIterator implements DataSetIterator {
private final int batchSize;
private final int truncateLength;

private int cursor = 0;
private final File[] positiveFiles;
private final File[] negativeFiles;

private final TokenizerFactory tokenizerFactory;
private final VocabCache vocab;

ImdbSentimentIterator(String dataDirectory, VocabCache vocab, int batchSize, int truncateLength, boolean train){
this.batchSize = batchSize;
this.vocab = vocab;

File p = new File(FilenameUtils.concat(dataDirectory, "aclImdb/" + (train ? "train" : "test") + "/pos/"));
File n = new File(FilenameUtils.concat(dataDirectory, "aclImdb/" + (train ? "train" : "test") + "/neg/"));
positiveFiles = p.listFiles();
negativeFiles = n.listFiles();

this.truncateLength = truncateLength;

tokenizerFactory = new DefaultTokenizerFactory();
tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
}

@Override
public DataSet next(int num) {
if (cursor >= totalExamples()) throw new NoSuchElementException();
try{
return nextDataSet(num);
}catch(IOException e){
throw new RuntimeException(e);
}
}

private DataSet nextDataSet(int num) throws IOException {
//First: load reviews into Strings, alternating positive and negative reviews
List<String> reviews = new ArrayList<>(num);
boolean[] positive = new boolean[num];
for( int i=0; i<num && cursor<totalExamples(); i++ ){
if(cursor % 2 == 0){
//Load positive review
int posReviewNumber = cursor / 2;
String review = FileUtils.readFileToString(positiveFiles[posReviewNumber], (Charset)null);
reviews.add(review);
positive[i] = true;
} else {
//Load negative review
int negReviewNumber = cursor / 2;
String review = FileUtils.readFileToString(negativeFiles[negReviewNumber], (Charset)null);
reviews.add(review);
positive[i] = false;
}
cursor++;
}

//Second: tokenize reviews and filter out unknown words
List<List<Integer>> allTokensIndex = new ArrayList<>(reviews.size());
for(String s : reviews){
List<String> tokens = tokenizerFactory.create(s).getTokens();
List<Integer> tokensIndex = new ArrayList<>();
for(String t : tokens ){
if(vocab.hasToken(t)){
//Shift every vocabulary index up by one: index 0 is reserved for unknown words
tokensIndex.add(vocab.indexOf(t)+1);
}else{
//Out-of-vocabulary word: map it to the reserved index 0
tokensIndex.add(0);
}
}
allTokensIndex.add(tokensIndex);
}

//Create data for training: features are sequences of word indices, shape [miniBatchSize, 1, truncateLength]
INDArray features = Nd4j.create(reviews.size(), 1, this.truncateLength);
INDArray labels = Nd4j.create(reviews.size(), 2); //Two labels: positive or negative

//Mask arrays contain 1 if data is present at that time step for that example, or 0 if data is just padding
INDArray featuresMask = Nd4j.zeros(reviews.size(), this.truncateLength);

for( int i=0; i<reviews.size(); i++ ){
List<Integer> tokensIndex = allTokensIndex.get(i);

//Truncate reviews longer than truncateLength; shorter reviews are aligned to the end
//of the sequence so that the padding (and the mask's zeros) sits at the front
int seqLength = Math.min(tokensIndex.size(), this.truncateLength);

int startSeqIndex = this.truncateLength - seqLength;

// Assign token index into feature array
features.put(
new INDArrayIndex[] {
NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.interval(startSeqIndex, this.truncateLength)
},
Nd4j.create(tokensIndex.subList(0, seqLength)));

// Assign "1" to each position where a feature is present
featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.interval(startSeqIndex, this.truncateLength)).assign(1);

int idx = (positive[i] ? 0 : 1);
labels.putScalar(new int[]{i,idx},1.0); //Set label: [0,1] for negative, [1,0] for positive
}

return new DataSet(features,labels,featuresMask,null);
}

private int totalExamples() {
return positiveFiles.length + negativeFiles.length;
}

@Override
public int inputColumns() {
return 1; //One feature (a word index) per time step
}

@Override
public int totalOutcomes() {
return 2;
}

@Override
public boolean resetSupported() {
return true;
}

@Override
public boolean asyncSupported() {
return true;
}

@Override
public void reset() {
cursor = 0;
}

@Override
public int batch() {
return batchSize;
}

@Override
public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
throw new UnsupportedOperationException();
}

@Override
public DataSetPreProcessor getPreProcessor() {
return null;
}

@Override
public List<String> getLabels() {
return Arrays.asList("positive","negative");
}

@Override
public boolean hasNext() {
return cursor < totalExamples();
}

@Override
public DataSet next() {
return next(batchSize);
}
}
Word2VecInEmbeddingLayer.java
@@ -0,0 +1,210 @@
package org.deeplearning4j.examples.nlp.word2vec;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.examples.utilities.DataUtilities;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.text.sentenceiterator.FileSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.net.URL;
import java.util.Arrays;
import java.util.List;

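/**
 * End-to-end example: train a word2vec model on the raw IMDB reviews, then reuse the learned
 * word vectors as the weights of a frozen EmbeddingLayer feeding an LSTM sentiment classifier.
 */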
public class Word2VecInEmbeddingLayer {

private static final Logger log = LoggerFactory.getLogger(Word2VecInEmbeddingLayer.class);

public static final String DATA_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz";

public static final String DATA_PATH = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_w2vSentiment/");

public static final String WORD2VEC_PATH = FilenameUtils.concat(DATA_PATH, "wordvectors.txt");

public static final int embeddingSize = 100;

public static WordVectors wordVectors = null;

public static void main(String[] args) throws Exception {
downloadData();

File word2vecFile = new File(WORD2VEC_PATH);
if (!word2vecFile.exists()){
trainWord2Vec();
}

wordVectors = WordVectorSerializer.readWord2VecModel(WORD2VEC_PATH);

VocabCache vocab = wordVectors.vocab();

//Mini-batch size 50; every review is truncated/padded to 150 word indices
ImdbSentimentIterator trainIter = new ImdbSentimentIterator(DATA_PATH, vocab, 50, 150, true);
ImdbSentimentIterator testIter = new ImdbSentimentIterator(DATA_PATH, vocab, 50, 150, false);

//Build the embedding table from the trained word vectors. Row 0 is left as zeros for
//unknown words, so the table needs numWords()+1 rows to match the iterator's +1 index shift
INDArray table = Nd4j.zeros(vocab.numWords() + 1, embeddingSize);
vocab.words().stream().forEach(x -> table.putRow(vocab.indexOf((String)x)+1, Nd4j.create(wordVectors.getWordVector((String)x))));

// define the network: frozen pretrained embedding -> LSTM -> last time step -> dense -> batch norm -> softmax output
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123)
.weightInit(WeightInit.XAVIER)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Adam(0.0001))
.l2(0.0001)
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(1.0)
.graphBuilder()
.addInputs("input")
.setInputTypes(InputType.recurrent(150))
.addLayer("embeddingEncoder",
new FrozenLayer.Builder().layer(
new EmbeddingLayer.Builder()
.nIn(vocab.numWords() + 1) //+1 for the reserved unknown-word index 0
.nOut(embeddingSize)
.activation(Activation.IDENTITY)
.biasInit(0.0)
.build()
).build(),
"input")
.addLayer("lstm",
new LSTM.Builder()
.weightInitRecurrent(WeightInit.XAVIER)
.nIn(embeddingSize)
.nOut(300)
.activation(Activation.TANH)
.build(),
"embeddingEncoder")
.addVertex("last", new LastTimeStepVertex("input"), "lstm")
.addLayer("dense1",
new DenseLayer.Builder()
.nIn(300)
.nOut(100)
.activation(Activation.LEAKYRELU)
.build(),
"last")
.addLayer("bn1", new BatchNormalization.Builder().build(), "dense1")
.addLayer("output",
new OutputLayer.Builder()
.nIn(100)
.nOut(2)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT)
.build(),
"bn1")
.setOutputs("output")
.build();

ComputationGraph model = new ComputationGraph(conf);
model.init();
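//Overwrite the frozen embedding layer's weights with the pretrained word2vec table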
model.getLayer("embeddingEncoder").setParam("W", table);

log.info(model.summary());

// attach UI stats and score listeners to monitor training
UIServer uiServer = UIServer.getInstance();
StatsStorage statsStorage = new InMemoryStatsStorage();
uiServer.attach(statsStorage);
model.setListeners(new StatsListener(statsStorage),
new ScoreIterationListener(10)
);

System.out.println("Starting training...");
model.fit(trainIter, 2);

System.out.println("Evaluating...");
Evaluation eval = model.evaluate(testIter);
System.out.println(eval.stats());
}

public static void trainWord2Vec(){

//FileSentenceIterator recursively reads every review file under train/ (pos, neg and unsup)
File dataDir = new File(DATA_PATH + "aclImdb/train/");
SentenceIterator iter = new FileSentenceIterator(dataDir);
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());

log.info("Building model....");
Word2Vec vec = new Word2Vec.Builder()
.minWordFrequency(7)
.epochs(3)
.layerSize(embeddingSize)
.seed(42)
.windowSize(5)
.iterate(iter)
.tokenizerFactory(t)
.build();

log.info("Fitting Word2Vec model....");
vec.fit();

log.info("Training Finished");

// save the model so future runs can skip word2vec training
WordVectorSerializer.writeWord2VecModel(vec, WORD2VEC_PATH);
log.info("Model Saved");
}


public static void downloadData() throws Exception {
//Create directory if required
File directory = new File(DATA_PATH);
if (!directory.exists()) directory.mkdir();

//Download file:
String archivePath = DATA_PATH + "aclImdb_v1.tar.gz";
File archiveFile = new File(archivePath);
String extractedPath = DATA_PATH + "aclImdb";
File extractedFile = new File(extractedPath);

if (!archiveFile.exists()) {
System.out.println("Starting data download (80MB)...");
FileUtils.copyURLToFile(new URL(DATA_URL), archiveFile);
System.out.println("Data (.tar.gz file) downloaded to " + archiveFile.getAbsolutePath());
//Extract tar.gz file to output directory
DataUtilities.extractTarGz(archivePath, DATA_PATH);
} else {
//Assume if archive (.tar.gz) exists, then data has already been extracted
System.out.println("Data (.tar.gz file) already exists at " + archiveFile.getAbsolutePath());
if (!extractedFile.exists()) {
//Extract tar.gz file to output directory
DataUtilities.extractTarGz(archivePath, DATA_PATH);
} else {
System.out.println("Data (extracted) already exists at " + extractedFile.getAbsolutePath());
}
}

//Delete the bag-of-words and URL metadata files from train/ so that word2vec training,
//which reads every file under that directory, only sees the review texts
List<String> filesToDelete = Arrays.asList("labeledBow.feat", "unsupBow.feat", "urls_pos.txt",
"urls_neg.txt", "urls_unsup.txt");

filesToDelete.forEach(f -> new File(extractedPath + "/train/" + f).delete());
}
}