Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Problem with Informer while using ProbAttention #47

Open
Nodon447 opened this issue Oct 9, 2023 · 3 comments
Open

Problem with Informer while using ProbAttention #47

Nodon447 opened this issue Oct 9, 2023 · 3 comments

Comments

@Nodon447
Copy link

Nodon447 commented Oct 9, 2023

I wanted to use the Informer model to predict time series,
but even the example doesn't work for me when setting prob_attention to True.

so this is my code:

params: Dict[str, Any] = {
"n_encoder_layers": 1,
"n_decoder_layers": 1,
"attention_hidden_sizes": 32 * 1,
"num_heads": 1,
"attention_dropout": 0.0,
"ffn_hidden_sizes": 32 * 1,
"ffn_filter_sizes": 32 * 1,
"ffn_dropout": 0.0,
"skip_connect_circle": False,
"skip_connect_mean": False,
"prob_attention": False,
"distil_conv": False,
}

custom_params = params.copy()
custom_params["prob_attention"] = True

option1: np.ndarray

train_length = 49
predict_length = 10
n_encoder_feature = 2
n_decoder_feature = 3

x_train = (
np.random.rand(1, train_length, 1), # inputs: (batch, train_length, 1)
np.random.rand(1, train_length, n_encoder_feature), # encoder_feature: (batch, train_length, encoder_features)
np.random.rand(1, predict_length, n_decoder_feature), # decoder_feature: (batch, predict_length, decoder_features)
)
y_train = np.random.rand(1, predict_length, 1) # target: (batch, predict_length, 1)

x_valid = (
np.random.rand(1, train_length, 1),
np.random.rand(1, train_length, n_encoder_feature),
np.random.rand(1, predict_length, n_decoder_feature),
)
y_valid = np.random.rand(1, predict_length, 1)

model = AutoModel("Informer", predict_length=predict_length,custom_model_params=custom_params)
trainer = KerasTrainer(model)
trainer.train((x_train, y_train), (x_valid, y_valid), n_epochs=1)

and this is the error:

TypeError Traceback (most recent call last)
Cell In[9], line 45
43 model = AutoModel("Informer", predict_length=predict_length,custom_model_params=custom_params)
44 trainer = KerasTrainer(model)
---> 45 trainer.train((x_train, y_train), (x_valid, y_valid), n_epochs=1)

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tfts/trainer.py:289, in KerasTrainer.train(self, train_dataset, valid_dataset, n_epochs, batch_size, steps_per_epoch, callback_eval_metrics, early_stopping, checkpoint, verbose, **kwargs)
286 else:
287 raise ValueError("tfts inputs should be either tf.data instance or 3d array list/tuple")
--> 289 self.model = self.model.build_model(inputs=inputs)
291 # print(self.model.summary())
292 self.model.compile(loss=self.loss_fn, optimizer=self.optimizer, metrics=callback_eval_metrics, run_eagerly=True)

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tfts/models/auto_model.py:81, in AutoModel.build_model(self, inputs)
80 def build_model(self, inputs):
---> 81 outputs = self.model(inputs)
82 return tf.keras.Model([inputs], [outputs])

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tfts/models/informer.py:120, in Informer.call(self, inputs, teacher)
115 decoder_feature = tf.cast(
116 tf.reshape(tf.range(self.predict_sequence_length), (-1, self.predict_sequence_length, 1)), tf.float32
117 )
119 encoder_feature = self.encoder_embedding(encoder_feature) # batch * seq * embedding_size
--> 120 memory = self.encoder(encoder_feature, mask=None)
122 B, L, _ = tf.shape(decoder_feature)
123 casual_mask = CausalMask(B * self.params["num_heads"], L).mask

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
67 filtered_tb = _process_traceback_frames(e.__traceback__)
68 # To get the full stack trace, call:
69 # tf.debugging.disable_traceback_filtering()
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb

File /tmp/__autograph_generated_filewsa2jpfz.py:56, in outer_factory.<locals>.inner_factory.<locals>.tf__call(self, x, mask)
54 conv_layer = ag__.Undefined('conv_layer')
55 attn_layer = ag__.Undefined('attn_layer')
---> 56 ag__.if_stmt((ag__.ld(self).conv_layers is not None), if_body, else_body, get_state_2, set_state_2, ('x',), 1)
58 def get_state_3():
59 return (x,)

File /tmp/__autograph_generated_filewsa2jpfz.py:36, in outer_factory.<locals>.inner_factory.<locals>.tf__call.<locals>.if_body()
34 attn_layer = ag__.Undefined('attn_layer')
35 ag__.for_stmt(ag__.converted_call(ag__.ld(zip), (ag__.ld(self).layers, ag__.ld(self).conv_layers), None, fscope), None, loop_body, get_state, set_state, ('x',), {'iterate_names': '(attn_layer, conv_layer)'})
---> 36 x = ag__.converted_call(ag__.ld(self).layers[(- 1)], (ag__.ld(x), ag__.ld(mask)), None, fscope)

File /tmp/__autograph_generated_file32nu44x0.py:12, in outer_factory.<locals>.inner_factory.<locals>.tf__call(self, x, mask)
10 retval_ = ag__.UndefinedReturnValue()
11 input = ag__.ld(x)
---> 12 x = ag__.converted_call(ag__.ld(self).attn_layer, (ag__.ld(x), ag__.ld(x), ag__.ld(x), ag__.ld(mask)), None, fscope)
13 x = ag__.converted_call(ag__.ld(self).drop, (ag__.ld(x),), None, fscope)
14 x = (ag__.ld(x) + ag__.ld(input))

File /tmp/__autograph_generated_file6otuhk1u.py:16, in outer_factory.<locals>.inner_factory.<locals>.tf__call(self, q, k, v, mask)
14 (B, L, D) = ag__.ld(q).shape
15 (_, S, _) = ag__.ld(k).shape
---> 16 q_ = ag__.converted_call(ag__.ld(tf).reshape, (ag__.ld(q), (ag__.ld(B), ag__.ld(self).num_heads, ag__.ld(L), (- 1))), None, fscope)
17 k_ = ag__.converted_call(ag__.ld(tf).reshape, (ag__.ld(k), (ag__.ld(B), ag__.ld(self).num_heads, ag__.ld(S), (- 1))), None, fscope)
18 v_ = ag__.converted_call(ag__.ld(tf).reshape, (ag__.ld(v), (ag__.ld(B), ag__.ld(self).num_heads, ag__.ld(S), (- 1))), None, fscope)

TypeError: Exception encountered when calling layer "encoder_4" (type Encoder).

in user code:

File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tfts/models/informer.py", line 153, in call  *
    x = self.layers[-1](x, mask)
File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
    raise e.with_traceback(filtered_tb) from None
File "/tmp/__autograph_generated_file32nu44x0.py", line 12, in tf__call
    x = ag__.converted_call(ag__.ld(self).attn_layer, (ag__.ld(x), ag__.ld(x), ag__.ld(x), ag__.ld(mask)), None, fscope)
File "/tmp/__autograph_generated_file6otuhk1u.py", line 16, in tf__call
    q_ = ag__.converted_call(ag__.ld(tf).reshape, (ag__.ld(q), (ag__.ld(B), ag__.ld(self).num_heads, ag__.ld(L), (- 1))), None, fscope)

TypeError: Exception encountered when calling layer 'encoder_layer_4' (type EncoderLayer).

in user code:

    File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tfts/models/informer.py", line 183, in call  *
        x = self.attn_layer(x, x, x, mask)
    File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/tmp/__autograph_generated_file6otuhk1u.py", line 16, in tf__call
        q_ = ag__.converted_call(ag__.ld(tf).reshape, (ag__.ld(q), (ag__.ld(B), ag__.ld(self).num_heads, ag__.ld(L), (- 1))), None, fscope)

    TypeError: Exception encountered when calling layer 'prob_attention_8' (type ProbAttention).
    
    in user code:
    
        File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tfts/layers/attention_layer.py", line 203, in call  *
            q_ = tf.reshape(q, (B, self.num_heads, L, -1))
    
        TypeError: Failed to convert elements of (None, 1, 49, -1) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.
    
    
    Call arguments received by layer 'prob_attention_8' (type ProbAttention):
      • q=tf.Tensor(shape=(None, 49, 32), dtype=float32)
      • k=tf.Tensor(shape=(None, 49, 32), dtype=float32)
      • v=tf.Tensor(shape=(None, 49, 32), dtype=float32)
      • mask=None


Call arguments received by layer 'encoder_layer_4' (type EncoderLayer):
  • x=tf.Tensor(shape=(None, 49, 32), dtype=float32)
  • mask=None

Call arguments received by layer "encoder_4" (type Encoder):
• x=tf.Tensor(shape=(None, 49, 32), dtype=float32)
• mask=None

I have no clue how to fix this. So it would be really nice if anyone could help.

LongxingTan added a commit that referenced this issue Oct 11, 2023
* Update dataembed from tokenembed

* Update dataembed from tokenembed

* fix: modify the shape for graph mode #47

* fix: solve conflict of informer
@LongxingTan
Copy link
Owner

Hi @Nodon447 ,

You can try the latest version now.
This happens because the batch size is None in TensorFlow's static graph mode, while the prob attention needs the batch size as a concrete number.

Thanks

@Nodon447
Copy link
Author

Hi @LongxingTan
Thanks for your help and time! :) It's working for me now as long as the batch size is 1. But when the batch size is increased, I still get an error. Do you know what needs to be changed?

train_length = 49
predict_length = 10
n_encoder_feature = 2
n_decoder_feature = 3

x_train = (
np.random.rand(2, train_length, 1), # inputs: (batch, train_length, 1)
np.random.rand(2, train_length, n_encoder_feature), # encoder_feature: (batch, train_length, encoder_features)
np.random.rand(2, predict_length, n_decoder_feature), # decoder_feature: (batch, predict_length, decoder_features)
)
y_train = np.random.rand(2, predict_length, 1) # target: (batch, predict_length, 1)

x_valid = (
np.random.rand(2, train_length, 1),
np.random.rand(2, train_length, n_encoder_feature),
np.random.rand(2, predict_length, n_decoder_feature),
)
y_valid = np.random.rand(2, predict_length, 1)

model = AutoModel("Informer", predict_length=predict_length,custom_model_params=custom_params)
trainer = KerasTrainer(model)
trainer.train((x_train, y_train), (x_valid, y_valid), n_epochs=1)


InvalidArgumentError Traceback (most recent call last)
Cell In[8], line 22
20 model = AutoModel("Informer", predict_length=predict_length,custom_model_params=custom_params)
21 trainer = KerasTrainer(model)
---> 22 trainer.train((x_train, y_train), (x_valid, y_valid), n_epochs=1)

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tfts/trainer.py:310, in KerasTrainer.train(self, train_dataset, valid_dataset, n_epochs, batch_size, steps_per_epoch, callback_metrics, early_stopping, checkpoint, verbose, **kwargs)
307 if isinstance(train_dataset, (list, tuple)):
308 x_train, y_train = train_dataset
--> 310 self.history = self.model.fit(
311 x_train,
312 y_train,
313 validation_data=valid_dataset,
314 steps_per_epoch=steps_per_epoch,
315 epochs=n_epochs,
316 batch_size=batch_size,
317 verbose=verbose,
318 callbacks=callbacks,
319 )
320 else:
321 self.history = self.model.fit(
322 train_dataset,
323 validation_data=valid_dataset,
(...)
328 callbacks=callbacks,
329 )

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
67 filtered_tb = _process_traceback_frames(e.__traceback__)
68 # To get the full stack trace, call:
69 # tf.debugging.disable_traceback_filtering()
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tfts/models/informer.py:151, in Encoder.call(self, x, mask)
149 x = attn_layer(x, mask)
150 # x = conv_layer(x)
--> 151 x = self.layers[-1](x, mask)
153 else:
154 for attn_layer in self.layers:

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tfts/models/informer.py:181, in EncoderLayer.call(self, x, mask)
179 """Informer encoder layer call function"""
180 input = x
--> 181 x = self.attn_layer(x, x, x, mask)
182 x = self.drop(x)
183 x = x + input

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tfts/layers/attention_layer.py:218, in ProbAttention.call(self, q, k, v, mask)
215 u_q = u_q if u_q < L else L
216 u_k = u_k if u_k < S else S
--> 218 scores_top, index = self._prob_qk(q_, k_, u_k, u_q)
219 scores_top = scores_top * 1.0 / np.sqrt(D // self.num_heads)
221 context = self.get_initial_context(v, L)

File /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tfts/layers/attention_layer.py:166, in ProbAttention._prob_qk(self, q, k, sample_k, top_n)
163 batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis], (1, H, top_n))
164 head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis], (B, 1, top_n))
--> 166 idx = tf.stack([batch_indexes, head_indexes, m_top], axis=-1)
168 q_reduce = tf.gather_nd(q, idx)
169 qk = tf.matmul(q_reduce, tf.transpose(k, (0, 1, 3, 2)))

InvalidArgumentError: Exception encountered when calling layer 'prob_attention' (type ProbAttention).

{{function_node _wrapped__Pack_N_3_device/job:localhost/replica:0/task:0/device:CPU:0}} Shapes of all inputs must match: values[0].shape = [2,1,20] != values[2].shape = [2,20] [Op:Pack] name: stack

Call arguments received by layer 'prob_attention' (type ProbAttention):
• q=tf.Tensor(shape=(2, 49, 32), dtype=float32)
• k=tf.Tensor(shape=(2, 49, 32), dtype=float32)
• v=tf.Tensor(shape=(2, 49, 32), dtype=float32)
• mask=None

@LongxingTan
Copy link
Owner

@Nodon447

I have fixed it in the latest version; please try version 0.0.10.

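The squeeze call should specify the axis explicitly, as in the second line below (the first line is the old code). Without an axis, tf.squeeze removes every dimension of size 1, not just the intended one, so a batch or head dimension of size 1 is removed as well.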

Q_K_sample = tf.squeeze(tf.matmul(tf.expand_dims(q, -2), tf.einsum("...ij->...ji", K_sample)))
Q_K_sample = tf.squeeze(tf.matmul(tf.expand_dims(q, -2), tf.einsum("...ij->...ji", K_sample)), axis=3)

Thanks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants