Multi-threaded prediction for treelite4j #5

Open

thvasilo opened this issue Aug 25, 2022 · 2 comments

thvasilo commented Aug 25, 2022

Hello,

I'm running some benchmarks for treelite4j, testing out different batch sizes (splitting up a dataset into batches and predicting for each batch in sequence) and the number of threads passed to the Predictor object.

One thing I'm observing is that the number of threads set in the Predictor only seems to matter when my batch size is larger than 1: if I create a DMatrix with a single row and call predict on it, the number of threads the Predictor was created with doesn't seem to make a difference.

Also, batch size doesn't seem to have a large effect when prediction is single-threaded; is that expected as well?

Is it the case that multi-threading is only relevant when there's more than one row in the input DMatrix?

Would it be possible to use multi-threading for single-instance prediction as well, using each thread to predict for a single tree and merging the results at the end?
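
To make that concrete, here is a rough sketch of what I have in mind (predictWithSingleTree is hypothetical; treelite4j does not expose a per-tree prediction entry point today, so this only illustrates the idea and needs java.util.concurrent and java.util imports):

    // Hypothetical sketch only: parallelize over trees for a single input row, then merge.
    // predictWithSingleTree(...) is NOT a real treelite4j API; it stands in for whatever
    // per-tree accessor the library would need to expose for this to work.
    static double predictSingleRowOverTrees(double[] row, int numTree, int nThreads) throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(nThreads);
        try {
            List<Future<Double>> margins = new ArrayList<>(numTree);
            for (int treeId = 0; treeId < numTree; treeId++) {
                final int id = treeId;
                margins.add(pool.submit(() -> predictWithSingleTree(id, row)));
            }
            double sum = 0.0;
            for (Future<Double> margin : margins) {
                sum += margin.get();  // merge: sum the per-tree margins, as for a GBDT ensemble
            }
            return sum;
        } finally {
            pool.shutdown();
        }
    }

    // Placeholder for the hypothetical per-tree prediction call described above.
    static double predictWithSingleTree(int treeId, double[] row) {
        throw new UnsupportedOperationException("not available in treelite4j today");
    }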

JMH results:

Benchmark           (batchSize)  (datapointNumber)    (treeliteThreads)   Mode  Cnt  Score   Error  Units
treelitePrediction            1             100000                    1  thrpt    3  0.053 ± 0.009  ops/s
treelitePrediction            1             100000                    8  thrpt    3  0.053 ± 0.005  ops/s
treelitePrediction           10             100000                    1  thrpt    3  0.060 ± 0.007  ops/s
treelitePrediction           10             100000                    8  thrpt    3  0.156 ± 0.515  ops/s
treelitePrediction          100             100000                    1  thrpt    3  0.064 ± 0.016  ops/s
treelitePrediction          100             100000                    8  thrpt    3  0.228 ± 0.667  ops/s

Some example code:

package me.tvas.benchmark;

import java.io.IOException;
import java.util.Iterator;
import java.util.List;

import ml.dmlc.treelite4j.java.*;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.results.format.ResultFormatType;


@State(Scope.Benchmark)
public class PredictionBenchmarks {

    @Param({""})
    public String treeliteModelPath;

    @Param({"1000"})
    public int datapointNumber;

    @Param({"1"})
    public int batchSize;

    @Param({"1"})
    public int treeliteThreads;

    @Param({"/tmp"})
    public static String destinationFolder;


    Predictor treelitePredictor;
    DataSet randomDataSet;
    long numFeature;
    int numBatches;

    @Setup
    public void prepare() throws IOException, TreeliteError {
        Predictor treelitePredictor = new Predictor(this.treeliteModelPath, this.treeliteThreads, false);
        // Create random data for prediction.
        long numFeature = treelitePredictor.GetNumFeature();
        long rngSeed = 42;
        long[] shape = new long[]{datapointNumber, numFeature};
        Nd4j.getRandom().setSeed(rngSeed);
        INDArray randomDoubles = Nd4j.rand(0.0, 32000.0, Nd4j.getRandom(), shape);
        INDArray dummyLabels = Nd4j.ones(new long[]{datapointNumber, 1});


        this.treelitePredictor = treelitePredictor;
        this.randomDataSet = new DataSet(randomDoubles, dummyLabels);
        this.numFeature = numFeature;
    }

    @Benchmark
    public void treelitePrediction(Blackhole blackhole) throws TreeliteError {
        // Split the dataset into batches of batchSize rows and predict each batch in sequence.
        List<DataSet> batches = this.randomDataSet.batchBy(batchSize);
        Iterator<DataSet> datasetIterator = batches.iterator();

        while (datasetIterator.hasNext()) {
            DataSet batch = datasetIterator.next();
            INDArray features = batch.getFeatures();
            long currentBatchSize = features.shape()[0];

            // Flatten the batch to a row-major double array and wrap it in a dense DMatrix,
            // with NaN marking missing values.
            double[] doubleVector = features.data().asDouble();
            DMatrix dmat = new DMatrix(doubleVector, Double.NaN, currentBatchSize, this.numFeature);
            INDArray preds = treelitePredictor.predict(dmat, false, false);
            blackhole.consume(preds);
        }
    }


    public static void main(String[] args) throws Exception {

        Options opt = new OptionsBuilder()
                .include(PredictionBenchmarks.class.getSimpleName())
                .result(destinationFolder + "/" + "benchmarkResults.csv")
                .resultFormat(ResultFormatType.CSV)
                .forks(1)
                .threads(1)
                .jvmArgs("-ea")
                .build();

        new Runner(opt).run();
    }
}
hcho3 (Collaborator) commented Aug 30, 2022

Is it the case that multi-threading is only relevant when there's more than one row in the input DMatrix?

Yes, currently multi-threading is only useful when you have multiple rows in the input DMatrix. The rows of the DMatrix get distributed equally across worker threads.
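
To illustrate what this means for batch size 1, here is a rough sketch (the actual parallel loop runs inside the native treelite runtime, and the contiguous-block partitioning below is only an approximation of "distributed equally"):

    // Rough illustration of distributing DMatrix rows across worker threads.
    // With a single-row DMatrix only one worker has anything to do, which is why
    // nthread makes no visible difference for batch size 1.
    long numRow = 1;       // rows in the DMatrix
    int nthread = 8;       // worker threads configured on the Predictor
    long chunk = (numRow + nthread - 1) / nthread;   // rows per worker (rounded up)
    for (int t = 0; t < nthread; t++) {
        long begin = t * chunk;
        long end = Math.min(begin + chunk, numRow);
        // worker t predicts rows [begin, end); here workers 1..7 get an empty range
    }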

@guozhaochen

I am trying to call predict in a multi-threaded way (i.e., multiple caller threads invoking predict, instead of multiple worker threads inside the Predictor), so I set the thread count to 1 so the callers are not blocked by synchronization. However, I found that the JavaCPP library used by ND4J doesn't allow multi-threading either; see https://github.com/bytedeco/javacpp/blob/d23879af7a03a04c12b2374ae9d0850b9dda9d96/src/main/java/org/bytedeco/javacpp/Pointer.java#L699
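
For reference, this is the calling pattern I mean, roughly sketched below (modelPath, batches, and numCallerThreads are placeholders; whether predict() can really run concurrently is exactly what the linked JavaCPP restriction interferes with):

    // Sketch: nthread = 1 inside the Predictor, parallelism supplied by the application's
    // own threads. Requires java.util.concurrent.{ExecutorService, Executors, Future}
    // plus java.util.{List, ArrayList} and the treelite4j imports.
    static void predictWithCallerThreads(String modelPath, List<DMatrix> batches,
                                         int numCallerThreads) throws Exception {
        Predictor predictor = new Predictor(modelPath, 1, false);   // single worker thread
        ExecutorService pool = Executors.newFixedThreadPool(numCallerThreads);
        try {
            List<Future<?>> pending = new ArrayList<>();
            for (DMatrix batch : batches) {
                pending.add(pool.submit(() -> predictor.predict(batch, false, false)));
            }
            for (Future<?> p : pending) {
                p.get();   // wait for each batch and surface any prediction error
            }
        } finally {
            pool.shutdown();
        }
    }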

Any particular reason that we need to use INDArray from ND4J?

hcho3 transferred this issue from dmlc/treelite May 9, 2023