Skip to content

Commit

Permalink
Small update in example - 10 repetitions (discard first 5)
Browse files Browse the repository at this point in the history
  • Loading branch information
anamf committed Nov 30, 2015
1 parent 6f98e3d commit 929677e
Showing 1 changed file with 44 additions and 19 deletions.
Expand Up @@ -14,7 +14,9 @@
import eu.amidst.dynamic.variables.DynamicVariables;

import java.io.IOException;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
Expand All @@ -35,6 +37,7 @@ public class DynamicIS_Scalability implements AmidstOptionsHandler {
boolean connectChildrenTemporally = false;
boolean activateMiddleLayer = true;
int seed = 1;
int numberOfSamples = 1000;


public int getNumOfSequences() {
Expand Down Expand Up @@ -109,6 +112,14 @@ public void setSeed(int seed) {
this.seed = seed;
}

public int getNumberOfSamples() {
return numberOfSamples;
}

public void setNumberOfSamples(int numberOfSamples) {
this.numberOfSamples = numberOfSamples;
}

public void runExperiment(){
Random random = new Random(this.getSeed());

Expand Down Expand Up @@ -197,6 +208,8 @@ public void runExperiment(){
DataStream<DynamicDataInstance> dataPredict = dynamicSampler.sampleToDataBase(this.getNumOfSequences(),
this.getSequenceLength());

List<DynamicDataInstance> dataPredictList = dataPredict.stream().collect(Collectors.toList());


//********************************************************************************************
// DYNAMIC IS WITH FACTORED FRONTIER ALGORITHM
Expand All @@ -206,6 +219,7 @@ public void runExperiment(){
ImportanceSampling importanceSampling = new ImportanceSampling();
importanceSampling.setParallelMode(true);
importanceSampling.setKeepDataOnMemory(true);
importanceSampling.setSampleSize(this.getNumberOfSamples());
FactoredFrontierForDBN factoredFrontierForDBN = new FactoredFrontierForDBN(importanceSampling);
InferenceEngineForDBN.setInferenceAlgorithmForDBN(factoredFrontierForDBN);
//Then, we set the DBN model
Expand All @@ -214,28 +228,37 @@ public void runExperiment(){
UnivariateDistribution posterior = null;
int time = 0 ;

long start = System.nanoTime();
for (DynamicDataInstance instance : dataPredict) {
//The InferenceEngineForDBN must be reset at the begining of each Sequence.
if (instance.getTimeID()==0 && posterior != null) {
InferenceEngineForDBN.reset();
time=0;
}
//We also set the evidence.
InferenceEngineForDBN.addDynamicEvidence(instance);
double average = 0;
for (int j = 0; j < 15; j++) {
long start = System.nanoTime();
for (DynamicDataInstance instance : dataPredictList) {
//The InferenceEngineForDBN must be reset at the begining of each Sequence.
if (instance.getTimeID() == 0 && posterior != null) {
InferenceEngineForDBN.reset();
time = 0;
}
factoredFrontierForDBN.setSeed(j);

//We also set the evidence.
InferenceEngineForDBN.addDynamicEvidence(instance);

//Then we run inference
InferenceEngineForDBN.runInference();
//Then we run inference
InferenceEngineForDBN.runInference();

//Then we query the posterior of the target variable
posterior = InferenceEngineForDBN.getFilteredPosterior(varH1);
//Then we query the posterior of the target variable
posterior = InferenceEngineForDBN.getFilteredPosterior(varH1);

//We show the output
System.out.println("P(varH1|e[0:"+(time++)+"]) = "+posterior);
//We show the output
//System.out.println("P(varH1|e[0:" + (time++) + "]) = " + posterior);
}
long duration = (System.nanoTime() - start) / 1;
double seconds = duration / 1000000000.0;
if (j > 4) {
average += seconds;
}
}
long duration = (System.nanoTime() - start) / 1;
double seconds = duration / 1000000000.0;
System.out.println("Time for Dynamic IS = "+seconds+" secs");

System.out.println("Time for Dynamic IS = "+average/10+" secs");
}

public static void main(String[] args) throws IOException {
Expand Down Expand Up @@ -265,7 +288,8 @@ public String listOptions() {
"-linkNodes, false, Connects leaf nodes in consecutive time steps.\\"+
"-activateMiddleLayer, true, Create middle layer with two (temporaly connected) " +
"discrete hidden nodes.\\"+
"-seed, 1, seed to generate random numbers\\";
"-seed, 1, seed to generate random numbers\\"+
"-samples, 1000, Number of samples for IS";
}

@Override
Expand All @@ -284,5 +308,6 @@ public void loadOptions() {
this.setConnectChildrenTemporally(getBooleanOption("-linkNodes"));
this.setActivateMiddleLayer(getBooleanOption("-activateMiddleLayer"));
this.setSeed(this.getIntOption("-seed"));
this.setNumberOfSamples(this.getIntOption("-samples"));
}
}

0 comments on commit 929677e

Please sign in to comment.