package aima.core.learning.reinforcement.agent;

import aima.core.learning.reinforcement.PerceptStateReward;
import aima.core.util.FrequencyCounter;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/**
* Artificial Intelligence A Modern Approach (3rd Edition): page 837.<br>
* <br>
*
* <pre>
* function PASSIVE-TD-AGENT(percept) returns an action
* inputs: percept, a percept indicating the current state s' and reward signal r'
* persistent: π, a fixed policy
* U, a table of utilities, initially empty
* N<sub>s</sub>, a table of frequencies for states, initially zero
* s,a,r, the previous state, action, and reward, initially null
*
* if s' is new then U[s'] <- r'
* if s is not null then
* increment N<sub>s</sub>[s]
* U[s] <- U[s] + α(N<sub>s</sub>[s])(r + γU[s'] - U[s])
* if s'.TERMINAL? then s,a,r <- null else s,a,r <- s',π[s'],r'
* return a
* </pre>
*
* Figure 21.4 A passive reinforcement learning agent that learns utility
* estimates using temporal differences. The step-size function α(n) is
* chosen to ensure convergence, as described in the text.
*
* @param <S>
* the state type.
* @param <A>
* the action type.
*
* @author Ciaran O'Reilly
* @author Ravi Mohan
* @author Ruediger Lunde
*
*/
public class PassiveTDAgent<S, A> extends ReinforcementAgent<S, A> {
// persistent: π, a fixed policy
private Map<S, A> pi = new HashMap<>();
// U, a table of utilities, initially empty
private Map<S, Double> U = new HashMap<>();
// N<sub>s</sub>, a table of frequencies for states, initially zero
    private FrequencyCounter<S> Ns = new FrequencyCounter<>();
// s,a,r, the previous state, action, and reward, initially null
private S s = null;
private A a = null;
private Double r = null;
//
private double alpha = 0.0;
private double gamma = 0.0;
/**
* Constructor.
*
     * @param fixedPolicy
     *            π, a fixed policy.
     * @param alpha
     *            a fixed learning rate.
     * @param gamma
     *            the discount factor to be used.
*/
public PassiveTDAgent(Map<S, A> fixedPolicy, double alpha, double gamma) {
this.pi.putAll(fixedPolicy);
this.alpha = alpha;
this.gamma = gamma;
}
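
    // A minimal usage sketch (Cell and Action are hypothetical placeholder
    // types; any state/action pair works):
    //
    //   Map<Cell, Action> fixedPolicy = ...; // the policy π to evaluate
    //   PassiveTDAgent<Cell, Action> agent =
    //           new PassiveTDAgent<>(fixedPolicy, 0.1, 1.0);
    //   Optional<Action> action = agent.act(percept); // one percept per step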
    /**
     * Performs one step of passive reinforcement learning, updating the
     * utility estimates using temporal differences.
     *
     * @param percept
     *            a percept indicating the current state s' and reward signal
     *            r'.
     * @return an action, or Optional.empty() if the perceived state s' is
     *         terminal.
     */
@Override
public Optional<A> act(PerceptStateReward<S> percept) {
// if s' is new then U[s'] <- r'
S sDelta = percept.state();
double rDelta = percept.reward();
if (!U.containsKey(sDelta)) {
U.put(sDelta, rDelta);
}
// if s is not null then
        if (s != null) {
// increment N<sub>s</sub>[s]
Ns.incrementFor(s);
// U[s] <- U[s] + α(N<sub>s</sub>[s])(r + γU[s'] - U[s])
double U_s = U.get(s);
U.put(s, U_s + alpha(Ns, s) * (r + gamma * U.get(sDelta) - U_s));
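            // For intuition (illustrative numbers, not from the text): with
            // alpha=0.1, gamma=1.0, r=-0.04, U[s]=0.5 and U[s']=0.7, the new
            // U[s] is 0.5 + 0.1 * (-0.04 + 1.0 * 0.7 - 0.5) = 0.516, i.e. U[s]
            // moves a fraction alpha toward the TD target r + gamma * U[s'].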
}
// if s'.TERMINAL? then s,a,r <- null else s,a,r <- s',π[s'],r'
if (isTerminal(sDelta)) {
s = null;
a = null;
r = null;
} else {
s = sDelta;
a = pi.get(sDelta);
r = rDelta;
}
// return a
return Optional.ofNullable(a);
}
@Override
public Map<S, Double> getUtility() {
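        // Return a defensive copy so callers cannot mutate the learned table.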
return new HashMap<S, Double>(U);
}
@Override
public void reset() {
U = new HashMap<>();
Ns.clear();
s = null;
a = null;
r = null;
}
//
// PROTECTED METHODS
//
/**
     * AIMA3e pg. 836: 'if we change α from a fixed parameter to a function
     * that decreases as the number of times a state has been visited
     * increases, then U<sup>π</sup>(s) itself will converge to the correct
     * value.'<br>
* <br>
* <b>Note:</b> override this method to obtain the desired behavior.
*
* @param Ns
* a frequency counter of observed states.
* @param s
* the current state.
* @return the learning rate to use based on the frequency of the state
* passed in.
*/
protected double alpha(FrequencyCounter<S> Ns, S s) {
        // The default implementation returns a fixed learning rate,
        // irrespective of the number of times a state has been visited.
return alpha;
}
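
    // An example override (a sketch; assumes FrequencyCounter#getCount returns
    // the number of times the state has been observed): a decaying step size
    // such as α(n) = 1/(1 + n) satisfies the standard convergence conditions
    // (Σ α(n) diverges while Σ α(n)² converges):
    //
    //   @Override
    //   protected double alpha(FrequencyCounter<S> Ns, S s) {
    //       return 1.0 / (1.0 + Ns.getCount(s));
    //   }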
//
// PRIVATE METHODS
//
    private boolean isTerminal(S s) {
        // A state for which the policy specifies no action is considered
        // terminal.
        return pi.get(s) == null;
    }
}