"""
Code for running inference with transformer
"""
import torch
import torch.nn as nn

import utils
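
# NB: `utils.generate_square_subsequent_mask` is this repository's own helper.
# It is assumed to behave like torch.nn.Transformer.generate_square_subsequent_mask,
# but with independent dimensions. A minimal sketch of the assumed behaviour
# (not the repository's tested code):
#
#     def generate_square_subsequent_mask(dim1: int, dim2: int, device) -> torch.Tensor:
#         # -inf above the diagonal blocks attention to future positions
#         return torch.triu(
#             torch.full((dim1, dim2), float("-inf"), device=device), diagonal=1
#         )
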
def run_encoder_decoder_inference(
    model: nn.Module,
    src: torch.Tensor,
    forecast_window: int,
    batch_size: int,
    device: torch.device,
    batch_first: bool = False
) -> torch.Tensor:
"""
NB! This function is currently only tested on models that work with
batch_first = False
This function is for encoder-decoder type models in which the decoder requires
an input, tgt, which - during training - is the target sequence. During inference,
the values of tgt are unknown, and the values therefore have to be generated
iteratively.
This function returns a prediction of length forecast_window for each batch in src
NB! If you want the inference to be done without gradient calculation,
make sure to call this function inside the context manager torch.no_grad like:
with torch.no_grad:
run_encoder_decoder_inference()
The context manager is intentionally not called inside this function to make
it usable in cases where the function is used to compute loss that must be
backpropagated during training and gradient calculation hence is required.
If use_predicted_tgt = True:
To begin with, tgt is equal to the last value of src. Then, the last element
in the model's prediction is iteratively concatenated with tgt, such that
at each step in the for-loop, tgt's size increases by 1. Finally, tgt will
have the correct length (target sequence length) and the final prediction
will be produced and returned.
Args:
model: An encoder-decoder type model where the decoder requires
target values as input. Should be set to evaluation mode before
passed to this function.
src: The input to the model
forecast_horizon: The desired length of the model's output, e.g. 58 if you
want to predict the next 58 hours of FCR prices.
batch_size: batch size
batch_first: If true, the shape of the model input should be
[batch size, input sequence length, number of features].
If false, [input sequence length, batch size, number of features]
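
    Example:
        A minimal usage sketch (the tensor shapes and device below are
        assumptions for illustration, not the repository's tested setup):

            model.eval()
            src = torch.rand(24, 4, 3)  # [input seq len, batch size, features]
            with torch.no_grad():
                prediction = run_encoder_decoder_inference(
                    model=model,
                    src=src,
                    forecast_window=12,
                    batch_size=src.shape[1],
                    device=torch.device("cpu"),
                )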
"""
    # Dimension of a batched model input that contains the target sequence values
    target_seq_dim = 1 if batch_first else 0

    # Take the last value of the target variable in all batches in src and make it tgt
    # as per the Influenza paper
    tgt = src[-1, :, 0] if not batch_first else src[:, -1, 0]  # shape [batch_size]

    # Change shape from [batch_size] to [1, batch_size, 1]
    # (this also covers batch_size == 1, giving [1, 1, 1])
    if not batch_first:
        tgt = tgt.unsqueeze(0).unsqueeze(-1)
    # Change shape from [batch_size] to [batch_size, 1, 1]
    else:
        tgt = tgt.unsqueeze(-1).unsqueeze(-1)
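
    # Shape walk-through (an illustrative example, assuming batch_first == False):
    #   with src of shape [24, 4, 3], tgt starts as [1, 4, 1] and grows by one
    #   time step per loop iteration below, ending at [forecast_window, 4, 1].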

    # Iteratively concatenate tgt with the last element of the model's prediction
    for _ in range(forecast_window - 1):

        # Sequence lengths used to size the attention masks
        dim_a = tgt.shape[1] if batch_first else tgt.shape[0]  # tgt sequence length
        dim_b = src.shape[1] if batch_first else src.shape[0]  # src sequence length

        # Causal mask for the decoder input, shape [tgt seq len, tgt seq len]
        tgt_mask = utils.generate_square_subsequent_mask(
            dim1=dim_a,
            dim2=dim_a,
            device=device
        )

        # Mask for the encoder output as seen by the decoder,
        # shape [tgt seq len, src seq len]
        src_mask = utils.generate_square_subsequent_mask(
            dim1=dim_a,
            dim2=dim_b,
            device=device
        )

        # Make prediction
        prediction = model(src, tgt, src_mask, tgt_mask)

        # The if statement makes sure that the predicted value is
        # extracted and reshaped correctly
        if not batch_first:
            # Obtain the predicted value at t+1, where t is the last time step
            # represented in tgt
            last_predicted_value = prediction[-1, :, :]

            # Reshape from [batch_size, 1] --> [1, batch_size, 1]
            last_predicted_value = last_predicted_value.unsqueeze(0)
        else:
            # Obtain the predicted value at t+1
            last_predicted_value = prediction[:, -1, :]

            # Reshape from [batch_size, 1] --> [batch_size, 1, 1]
            last_predicted_value = last_predicted_value.unsqueeze(-1)

        # Detach the predicted element from the graph and concatenate it with
        # tgt in dimension 0 or 1; gradients still reach the model through the
        # final forward pass below
        tgt = torch.cat((tgt, last_predicted_value.detach()), target_seq_dim)

    # Create masks for the final forward pass
    dim_a = tgt.shape[1] if batch_first else tgt.shape[0]  # tgt sequence length
    dim_b = src.shape[1] if batch_first else src.shape[0]  # src sequence length

    tgt_mask = utils.generate_square_subsequent_mask(
        dim1=dim_a,
        dim2=dim_a,
        device=device
    )

    src_mask = utils.generate_square_subsequent_mask(
        dim1=dim_a,
        dim2=dim_b,
        device=device
    )

    # Make the final prediction, now that tgt has the full target sequence length
    final_prediction = model(src, tgt, src_mask, tgt_mask)

    return final_prediction
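

# Minimal smoke test (a sketch; `_DummyModel` below is an assumption used only
# to exercise the tensor shapes, not the repository's actual transformer):
if __name__ == "__main__":

    class _DummyModel(nn.Module):
        """Stand-in encoder-decoder model returning one value per tgt step."""
        def forward(self, src, tgt, src_mask, tgt_mask):
            # Output shape [tgt seq len, batch size, 1], matching what the
            # inference loop expects when batch_first == False
            return tgt[..., :1] * 0.5

    device = torch.device("cpu")
    src = torch.rand(24, 4, 3)  # [input seq len, batch size, num features]

    with torch.no_grad():
        pred = run_encoder_decoder_inference(
            model=_DummyModel(),
            src=src,
            forecast_window=12,
            batch_size=src.shape[1],
            device=device,
        )

    print(pred.shape)  # expected: torch.Size([12, 4, 1])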