ParallelInferenceExample.java
/*******************************************************************************
 *
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.examples.advanced.features.inference;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelInference;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.File;

/**
 * This example shows the use of the ParallelInference mechanism.
 * ParallelInference takes requests from multiple threads,
 * collects them for a short while, and then queries the model with all collected requests at once.
 * Since the model works in parallel internally, the available resources are still fully utilized.
 *
 * Refer to: https://www.dubs.tech/guides/quickstart-with-dl4j/#parallel-inference for more information
 *
 * @author raver119@gmail.com
 */
public class ParallelInferenceExample {

    public static void main(String[] args) throws Exception {
        // Use the path to your model file here, or instantiate a model any other way
        MultiLayerNetwork model = MultiLayerNetwork.load(new File("PATH_TO_YOUR_MODEL_FILE"), false);

        ParallelInference pi = new ParallelInference.Builder(model)
            // BATCHED mode is an optimization: if the number of incoming requests is high,
            // PI groups individual queries into a single batch; if the load is low,
            // queries are processed without batching.
            .inferenceMode(InferenceMode.BATCHED)
            // Maximum batch size for BATCHED mode. Set this value with respect to your
            // environment (e.g. the amount of GPU memory available).
            .batchLimit(32)
            // Set this to the number of available computational devices, either CPUs or GPUs
            .workers(2)
            .build();
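
        // Note: besides BATCHED, InferenceMode also offers SEQUENTIAL, which processes
        // each request independently without batching.
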
        // PLEASE NOTE: this output() call is just a placeholder; pass data with the same
        // dimensionality that was used during training
        INDArray result = pi.output(new float[] {0.1f, 0.1f, 0.1f, 0.2f, 0.3f});
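
        // A minimal sketch (not part of the original example, thread names are arbitrary):
        // ParallelInference is designed to be called from multiple threads concurrently,
        // which lets PI batch their requests together. The input values are placeholders,
        // as above.
        Runnable request = () -> {
            INDArray out = pi.output(new float[] {0.1f, 0.1f, 0.1f, 0.2f, 0.3f});
            System.out.println("Output shape: " + java.util.Arrays.toString(out.shape()));
        };
        Thread t1 = new Thread(request);
        Thread t2 = new Thread(request);
        t1.start();
        t2.start();
        t1.join();
        t2.join();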
    }
}