-
Notifications
You must be signed in to change notification settings - Fork 153
/
DecisionVariableElimination.scala
274 lines (233 loc) · 11.9 KB
/
DecisionVariableElimination.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
/*
* DecisionVariableElimination.scala
* Variable elimination for Decisions algorithm.
*
* Created By: Brian Ruttenberg (bruttenberg@cra.com)
* Creation Date: Oct 1, 2012
*
* Copyright 2013 Avrom J. Pfeffer and Charles River Analytics, Inc.
* See http://www.cra.com or email figaro@cra.com for information.
*
* See http://www.github.com/p2t2/figaro for a copy of the software license.
*/
package com.cra.figaro.algorithm.decision
import com.cra.figaro.algorithm._
import com.cra.figaro.algorithm.factored._
import com.cra.figaro.algorithm.factored.factors._
import com.cra.figaro.algorithm.sampling._
import com.cra.figaro.language._
import com.cra.figaro.library.decision._
import com.cra.figaro.util._
import com.cra.figaro.algorithm.lazyfactored.Extended
import annotation.tailrec
import scala.collection.mutable.{ Map, Set }
import scala.language.existentials
import com.cra.figaro.algorithm.factored.factors.factory.Factory
/**
 * Trait for decision-based variable elimination.
 *
 * This implementation is hard-coded to use Double utilities: every factor holds (Double, Double) tuples of
 * (probability, utility). To use other utility types, provide another trait or convert the utilities to Double.
 */
trait ProbabilisticVariableEliminationDecision extends VariableElimination[(Double, Double)] {
  /** Retrieves the utility nodes in the model. Implementations must define this. */
  def getUtilityNodes: List[Element[_]]

  /**
   * Semiring for decisions: a sum-product-utility semiring over (probability, utility) tuples.
   */
  override val semiring = SumProductUtilitySemiring()

  /**
   * Makes a utility factor for an element designated as a utility. The factor holds (Double, Double) tuples
   * where the first value is 1.0 and the second is the corresponding (Double) value in the element's range.
   */
  def makeUtilFactor(e: Element[_]): Factor[(Double, Double)] = {
    val f = Factory.defaultFactor[(Double, Double)](List(), List(Variable(e)), semiring)
    // The single extended value in each row is the utility; its probability component is fixed at 1.0.
    f.fillByRule((l: List[Any]) => (1.0, l.asInstanceOf[List[Extended[Double]]](0).value))
    f
  }

  /* Even though utility nodes are eliminated, we need to create factors for them and anything they use. */
  override def starterElements = getUtilityNodes ::: targetElements

  /**
   * Creates the decision factors. Every factor is hard-coded to hold (Double, Double) tuples, where the
   * first value is the probability and the second is the utility.
   *
   * Ordinary factors for `neededElements` and the dependent-universe factors are converted so that their
   * utility component is 0.0. A separate utility factor (see `makeUtilFactor`) is added for each utility node.
   *
   * @param neededElements elements for which ordinary probability factors are created
   * @param targetElements not used by this implementation
   * @param upper not used by this implementation
   * @return the converted dependent-universe factors, then the converted element factors, then the utility factors
   */
  def getFactors(neededElements: List[Element[_]], targetElements: List[Element[_]], upper: Boolean = false): List[Factor[(Double, Double)]] = {
    if (debug) {
      println("Elements (other than utilities) appearing in factors and their ranges:")
      for { element <- neededElements } {
        println(Variable(element).id + "(" + element.name.string + "@" + element.hashCode + ")" + ": " + element + ": " + Variable(element).range.mkString(","))
      }
    }
    val thisUniverseFactorsExceptUtil = neededElements flatMap (Factory.makeFactorsForElement(_))
    // Make special utility factors for utility elements; these are already (probability, utility) factors.
    val thisUniverseFactorsUtil = getUtilityNodes map (makeUtilFactor(_))
    val dependentUniverseFactors =
      for { (dependentUniverse, evidence) <- dependentUniverses } yield Factory.makeDependentFactor(Variable.cc, universe, dependentUniverse, dependentAlgorithm(dependentUniverse, evidence))
    // Convert all non-utility factors from standard factors to decision factors, i.e. (probability, 0.0) tuples.
    val convertedExceptUtil = thisUniverseFactorsExceptUtil.map(s => convert(s, utility = false))
    val convertedDependent = dependentUniverseFactors.map(s => convert(s, utility = false))
    convertedDependent ::: convertedExceptUtil ::: thisUniverseFactorsUtil
  }

  /*
   * Converts a Double (probability) factor into a factor of (probability, utility) tuples.
   *
   * When `utility` is false, every entry keeps its probability and gets a utility of 0.0.
   * When `utility` is true, the factor must be over a single variable (more than one throws
   * IllegalUtilityNodeException). Each entry keeps its probability, and its utility is the corresponding
   * range value of that variable, read as a Double.
   */
  private def convert(f: Factor[Double], utility: Boolean): Factor[(Double, Double)] = {
    if (utility && f.variables.length > 1) throw new IllegalUtilityNodeException
    val converted = f.mapTo[(Double, Double)]((d: Double) => (d, 0.0), semiring.asInstanceOf[Semiring[(Double, Double)]])
    if (utility) {
      val utilityVariable = f.variables(0)
      for { i <- 0 until utilityVariable.range.size } {
        // Range entries are Extended values, so unwrap with .value before reading the utility as a Double.
        converted.set(List(i), (converted.get(List(i))._1, utilityVariable.range(i).value.asInstanceOf[Double]))
      }
    }
    converted
  }
}
/**
 * Decision variable elimination algorithm that computes the expected utility of decision elements using the
 * default elimination order.
 *
 * @param universe the universe in which the decision, its parent and the utility nodes live
 * @param utilityNodes elements whose (Double) values are treated as utilities
 * @param target the decision element; its first argument (`target.args(0)`) is treated as its parent
 * @param dependentUniverses dependent universes, each paired with its named evidence
 * @param dependentAlgorithm given a dependent universe and its evidence, returns a thunk computing the
 *                           probability of evidence in that universe
 */
class ProbQueryVariableEliminationDecision[T, U](override val universe: Universe, utilityNodes: List[Element[_]], target: Element[_])(
  val showTiming: Boolean,
  val dependentUniverses: List[(Universe, List[NamedEvidence[_]])],
  val dependentAlgorithm: (Universe, List[NamedEvidence[_]]) => () => Double)
  extends OneTimeProbQuery
  with ProbabilisticVariableEliminationDecision
  with DecisionAlgorithm[T, U] {

  lazy val queryTargets = List(target)

  /**
   * Variable elimination eliminates every variable except the decision node and its parent.
   * Thus the target elements are the decision element and its parent element.
   */
  val targetElements = List(target, target.args(0))

  def getUtilityNodes = utilityNodes

  // Product of all factors left after elimination; replaced by `finish`.
  private var finalFactors: Factor[(Double, Double)] = Factory.defaultFactor[(Double, Double)](List(), List(), semiring)

  /*
   * Marginalizes `factor` onto `target`, then normalizes the probability component and keeps the utility
   * component unchanged. Records the result in `targetFactors`.
   */
  private def marginalizeToTarget(factor: Factor[(Double, Double)], target: Element[_]): Unit = {
    val unnormalizedTargetFactor = factor.marginalizeTo(Variable(target))
    // Only z._1, the total probability mass, is used; the utility component of the fold is discarded.
    val z = unnormalizedTargetFactor.foldLeft(semiring.zero, (x: (Double, Double), y: (Double, Double)) => (x._1 + y._1, 0.0))
    val targetFactor = unnormalizedTargetFactor.mapTo((d: (Double, Double)) => (d._1 / z._1, d._2))
    targetFactors += target -> targetFactor
  }

  /*
   * Multiplies all factors remaining after elimination into a single factor.
   */
  private def makeResultFactor(factorsAfterElimination: MultiSet[Factor[(Double, Double)]]): Factor[(Double, Double)] = {
    // It is possible that there are no factors (this will happen if there are no decisions or utilities).
    // Therefore, we start with the unit factor and use foldLeft, instead of simply reducing the factorsAfterElimination.
    factorsAfterElimination.foldLeft(Factory.unit(semiring))(_.product(_))
  }

  def finish(factorsAfterElimination: MultiSet[Factor[(Double, Double)]], eliminationOrder: List[Variable[_]]) =
    finalFactors = makeResultFactor(factorsAfterElimination)

  /**
   * Returns the distribution of the target, ignoring utilities.
   *
   * If no marginal has been recorded for the target yet, it is computed from the final factor on demand.
   * Non-regular range values are skipped. Each regular value is read from the factor at its own position
   * in the full range, so the indices stay aligned with the factor's entries.
   */
  def computeDistribution[T](target: Element[T]): Stream[(Double, T)] = {
    if (!targetFactors.contains(target)) marginalizeToTarget(finalFactors, target)
    val factor = targetFactors(target)
    val targetVar = Variable(target)
    // Pair each range value with its index BEFORE filtering, so non-regular values do not shift later indices.
    val dist = targetVar.range.zipWithIndex.collect {
      case (xvalue, index) if xvalue.isRegular => (factor.get(List(index))._1, xvalue.value)
    }
    // normalization is unnecessary here because it is done in marginalizeToTarget
    dist.toStream
  }

  /**
   * Returns the expectation of `function` over the target's distribution, ignoring utilities.
   */
  def computeExpectation[T](target: Element[T], function: T => Double): Double =
    computeDistribution(target).foldLeft(0.0) { case (acc, (prob, value)) => acc + prob * function(value) }

  /**
   * Returns the computed utility of all parent/decision tuple values. For VE, these are not samples
   * but the actual computed expected utility for all combinations of the parent and decision.
   */
  def computeUtility(): scala.collection.immutable.Map[(T, U), DecisionSample] = computeStrategy(finalFactors)

  /*
   * Converts the final factor into a map from (parent value, decision value) to the utility component stored
   * in the factor for that combination. Each entry gets a DecisionSample weight of 1.0.
   */
  private def computeStrategy(factor: Factor[(Double, Double)]) = {
    val strat = Map[(T, U), DecisionSample]()
    // The decision variable is the element variable whose element is the target decision.
    val decisionVariable: Variable[_] = factor.variables.collectFirst {
      case v: ElementVariable[_] if v.element == target => v
    }.getOrElse(throw new IllegalArgumentException("Final factor does not contain the decision variable"))
    // Only the decision and its parent are kept out of elimination, so the remaining variable is the parent.
    val parentVariable = factor.variables.filterNot(_ == decisionVariable)(0)
    val decisionPos = factor.variables.indexOf(decisionVariable)
    val parentPos = factor.variables.indexOf(parentVariable)
    for { assignment <- factor.getIndices } {
      // Each (parent, decision) combination appears exactly once, so each entry is written once.
      val parent = parentVariable.range(assignment(parentPos)).value.asInstanceOf[T]
      val decision = decisionVariable.range(assignment(decisionPos)).value.asInstanceOf[U]
      val utility = factor.get(assignment)._2
      strat += ((parent, decision) -> DecisionSample(utility, 1.0))
    }
    strat.toMap
  }
}
object DecisionVariableElimination {
  /* Number of samples ProbEvidenceSampler uses by default to compute probability of evidence in dependent universes. */
  private val DefaultDependentSamples = 10000

  /* Default dependent algorithm: compute probability of evidence in a dependent universe with ProbEvidenceSampler. */
  private val defaultDependentAlgorithm: (Universe, List[NamedEvidence[_]]) => () => Double =
    (u: Universe, e: List[NamedEvidence[_]]) => () => ProbEvidenceSampler.computeProbEvidence(DefaultDependentSamples, e)(u)

  /*
   * Checks the conditions of Decision usage:
   * 1. Every utility node's current value must be a Double; otherwise an IllegalArgumentException is thrown.
   * The `target` parameter is not currently inspected.
   */
  private[decision] def usageCheck(utilityNodes: List[Element[_]], target: Decision[_, _]): Unit = {
    utilityNodes.foreach { u =>
      u.value match {
        case _: Double => ()
        case _ => throw new IllegalArgumentException("Only double utilities are allowed")
      }
    }
  }

  /**
   * Creates a decision variable elimination instance with the given decision variable and indicated utility
   * nodes. No dependent universes are used.
   */
  def apply[T, U](utilityNodes: List[Element[_]], target: Decision[T, U])(implicit universe: Universe): ProbQueryVariableEliminationDecision[T, U] =
    apply(Nil, defaultDependentAlgorithm, utilityNodes, target)(universe)

  /**
   * Creates a decision variable elimination algorithm with the given decision variable and indicated utility
   * nodes. It uses the given dependent universes in the current default universe, and the default
   * ProbEvidenceSampler-based algorithm to compute probability of evidence in them.
   */
  def apply[T, U](dependentUniverses: List[(Universe, List[NamedEvidence[_]])], utilityNodes: List[Element[_]], target: Decision[T, U])(implicit universe: Universe): ProbQueryVariableEliminationDecision[T, U] =
    apply(dependentUniverses, defaultDependentAlgorithm, utilityNodes, target)(universe)

  /**
   * Creates a decision variable elimination algorithm with the given decision variable and indicated utility
   * nodes, using the given dependent universes in the current default universe. The given dependent algorithm
   * function determines the algorithm used to compute probability of evidence in each dependent universe.
   */
  def apply[T, U](
    dependentUniverses: List[(Universe, List[NamedEvidence[_]])],
    dependentAlgorithm: (Universe, List[NamedEvidence[_]]) => () => Double,
    utilityNodes: List[Element[_]],
    target: Decision[T, U])(implicit universe: Universe): ProbQueryVariableEliminationDecision[T, U] = {
    utilityNodes.foreach(_.generate()) // need initial values for the utility nodes before the usage check
    usageCheck(utilityNodes, target)
    new ProbQueryVariableEliminationDecision[T, U](universe, utilityNodes, target)(
      false,
      dependentUniverses,
      dependentAlgorithm)
  }
}