-
Notifications
You must be signed in to change notification settings - Fork 1
/
predict_value.py
56 lines (39 loc) · 1.43 KB
/
predict_value.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
'''
This script predicts the value function at initial time on a grid.
'''
import numpy as np
import scipy.io
import time
import sys
from utilities.other import int_input, load_NN
from examples.choose_problem import system, problem, config, time_dependent
# Select the network class and the model subdirectory from the problem type.
# When time_dependent is false, HJBnet_t0 is used (bound to the name HJBnet)
# and files live under '<system>/t0'; otherwise HJBnet is used and files live
# under '<system>/tspan'.
if not time_dependent:
    from utilities.neural_networks import HJBnet_t0 as HJBnet
    system = system + '/t0'
else:
    from utilities.neural_networks import HJBnet
    system = system + '/tspan'

# Load the saved parameters and scaling from examples/<system>/V_model.mat and
# rebuild the network from them together with the problem and its config.
parameters, scaling = load_NN('/'.join(('examples', system, 'V_model.mat')))
model = HJBnet(problem, scaling, config, parameters)
# ---------------------------------------------------------------------------- #
# Evaluate the trained model at t = 0 on a regular grid spanning the initial-
# state bounds, then save the grid, value V and control U to a .mat file.
# A monotonic clock is used so the reported duration cannot be skewed by
# system clock adjustments.
pred_time = time.perf_counter()

ub = problem.X0_ub
lb = problem.X0_lb
plotdims = config.plotdims

# Grid points per plotted dimension: one entry per element of plotdims, so the
# grid adapts to however many dimensions are plotted (a fixed [100, 100] only
# worked for exactly two). Sizes are kept equal so that reshape(Nm) matches
# the axis order of np.meshgrid's default 'xy' indexing.
Nm = [100] * len(plotdims)
Nout = int(np.prod(Nm))

# States not in plotdims are held at the midpoint of their bounds.
# reshape(-1, 1) makes the midpoint a column whether the bounds are stored as
# 1-D or (n, 1) arrays, so X always has one column per grid point.
X = np.tile(((ub + lb) / 2.).reshape(-1, 1), (1, Nout))

# Build a meshgrid over the plotted dimensions and write each flattened
# coordinate array into its state row of X.
lb_flat, ub_flat = np.ravel(lb), np.ravel(ub)
axes = [np.linspace(lb_flat[dim], ub_flat[dim], n)
        for dim, n in zip(plotdims, Nm)]
X_mesh = np.meshgrid(*axes)
for dim, grid in zip(plotdims, X_mesh):
    X[dim, :] = grid.flatten()

# Time row of zeros: every grid point is evaluated at t = 0.
V = model.predict_V(np.zeros((1, Nout)), X).reshape(Nm)

pred_time = time.perf_counter() - pred_time
print('Prediction time: %.1f' % (pred_time))

# plotdims is shifted by +1, presumably for MATLAB's 1-based indexing (the
# output is a .mat file). eval_U(...)[0] keeps only the first entry of the
# returned control (assumed to be the first control component; confirm).
save_dict = {'plotdims': np.array(plotdims) + 1, 'X': X_mesh, 'V': V,
             'U': model.eval_U(np.zeros((1, Nout)), X)[0].reshape(Nm)}
scipy.io.savemat('examples/' + system + '/results/val_pred.mat', save_dict)