Can't run haste layers in Keras #33

Open
mahenning opened this issue Jun 5, 2021 · 12 comments

Comments

@mahenning

Hello,

I know this may be more of a debugging problem on my side, but I get the following error message when running my code, and it only appears when I use a haste layer:

Traceback (most recent call last):
  File "<string>", line 1331, in haste_lstm
  File "<string>", line 1379, in haste_lstm_eager_fallback
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 280, in args_to_matching_eager
    ret = [ops.convert_to_tensor(t, dtype, ctx=ctx) for t in l]
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 280, in <listcomp>
    ret = [ops.convert_to_tensor(t, dtype, ctx=ctx) for t in l]
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/profiler/trace.py", line 163, in wrapped
    return func(*args, **kwargs)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1540, in convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py", line 339, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py", line 265, in constant
    allow_broadcast=True)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py", line 276, in _constant_impl
    return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py", line 301, in _constant_eager_impl
    t = convert_to_eager_tensor(value, ctx, dtype)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/framework/constant_op.py", line 98, in convert_to_eager_tensor
    return ops.EagerTensor(value, ctx.device_name, dtype)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/keras/engine/keras_tensor.py", line 274, in __array__
    'Cannot convert a symbolic Keras input/output to a numpy array. '
TypeError: Cannot convert a symbolic Keras input/output to a numpy array. This error may indicate that you're trying to pass a symbolic value to a NumPy call, which is not supported. Or, you may be trying to pass Keras symbolic inputs/outputs to a TF API that does not register dispatching, preventing Keras from automatically converting the API call to a lambda layer in the Functional Model.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/snap/pycharm-professional/237/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/snap/pycharm-professional/237/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/time-series-on-joints-emg/src/all_in_one_file.py", line 394, in <module>
    x, state = haste1(x, training=True)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/haste_tf/base_rnn.py", line 115, in __call__
    result, state = self.fw_layer(inputs, sequence_length, training)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/haste_tf/lstm.py", line 218, in __call__
    zoneout_prob=self.zoneout)
  File "<string>", line 1339, in haste_lstm
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py", line 122, in dispatch
    result = dispatcher.handle(op, args, kwargs)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/keras/layers/core.py", line 1450, in handle
    return TFOpLambda(op)(*args, **kwargs)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 952, in __call__
    input_list)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1091, in _functional_construction_call
    inputs, input_masks, args, kwargs)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 822, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 863, in _infer_output_signature
    outputs = call_fn(inputs, *args, **kwargs)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/keras/layers/core.py", line 1327, in _call_wrapper
    return self._call_wrapper(*args, **kwargs)
  File "/mnt/SSD/Marko/Dokumente/Uni/SoSe21/MA/LSTM_testproject/envs/LSTM_testproject/lib/python3.7/site-packages/tensorflow/python/keras/layers/core.py", line 1359, in _call_wrapper
    result = self.function(*args, **kwargs)
TypeError: haste_lstm() missing 1 required positional argument: 'training'

I construct the model with the following code:

inputs = k_l.Input(shape=(train_x.shape[1], train_x.shape[2]))
direction = 'unidirectional' if args.model == 'GRU' else 'bidirectional'
haste1 = haste.LSTM(args.hidden_size, direction=direction, zoneout=0.1, dropout=args.dropout_time)
fc1 = k_l.Dense(args.dense_layers[0], activation='relu', kernel_initializer='he_uniform')
dr1 = k_l.Dropout(0.2)
fc2 = k_l.Dense(1)

x, state = haste1(inputs, training=True)
x = fc1(inputs)  # note: this feeds the raw inputs, so the haste output x is unused; fc1(x) was probably intended
x = dr1(x)
outputs = fc2(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(loss=loss_func, optimizer=optimizer)
model_hist = model.fit(train_x, train_y, epochs=args.epochs, batch_size=args.batch_size, verbose=1,
                       validation_data=val_data, callbacks=keras_callbacks)

train_x numpy array shape is (21788, 1000, 4)
OS: Ubuntu 20.04
Python version: 3.7
Keras: 2.4.3
Tensorflow: 2.4.1
numpy: 1.19.5
GPU: GTX 1060
CUDA: 11.2

Normally I wouldn't post error messages like these on GitHub, but since the code runs without the haste layer, I suspect the cause lies somewhere close to it. This repo seems to be the best place to ask, and I didn't find any solutions elsewhere. I hope you can help me; I'd really like to try out your implementation on my dataset.

@thegodone

thegodone commented Jun 16, 2021

Did you solve this issue?

@mahenning
Author

Unfortunately not yet. I tried to use it in a Sequential model, but for that it has to be an instance of class Layer.
For now I'm testing different approaches, as the error doesn't give me a starting point for a fix. If I find a solution, I'll post it here.

@mahenning
Author

Regarding the first error:

'Cannot convert a symbolic Keras input/output to a numpy array. '
TypeError: Cannot convert a symbolic Keras input/output to a numpy array. This error may indicate that you're trying to pass a symbolic value to a NumPy call, which is not supported. Or, you may be trying to pass Keras symbolic inputs/outputs to a TF API that does not register dispatching, preventing Keras from automatically converting the API call to a lambda layer in the Functional Model.

In the Keras API docs I found this:

Note that even if eager execution is enabled, Input produces a symbolic tensor-like object (i.e. a placeholder). This symbolic tensor-like object can be used with lower-level TensorFlow ops that take tensors as inputs, as such:
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

# this is a logistic regression in Keras
x = Input(shape=(32,))
y = Dense(16, activation='softmax')(x)
model = Model(x, y)

(This behavior does not work for higher-order TensorFlow APIs such as control flow and being directly watched by a tf.GradientTape).

The last frame in which the error is raised is ..keras/layers/core.py:1359, which runs inside a backprop.GradientTape context:

with backprop.GradientTape(watch_accessed_variables=True) as tape,\
    variable_scope.variable_creator_scope(_variable_creator):
  # We explicitly drop name arguments here,
  # to guard against the case where an op explicitly has a
  # name passed (which is susceptible to producing
  # multiple ops w/ the same name when the layer is reused)
  kwargs.pop('name', None)
  result = self.function(*args, **kwargs)

Does that mean haste layers can't be used with a Keras Input layer? If so, how can I create a model with them, when neither the Sequential nor the Functional API (which, AFAIK, requires the Input layer) can be used? In your example on the main page, the layer receives a tensor directly. Do I have to handle batches manually and feed the data in directly instead of letting the fit function do it?
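If it did come to a manual loop, the main piece of fit() to reproduce is its batching: by default fit() shuffles the samples every epoch and keeps the final, smaller batch. A minimal, framework-agnostic sketch of just that step, assuming data that can be indexed by a list of indices (as NumPy arrays can); the helper name minibatch_indices is hypothetical and not part of Keras or haste:

```python
import random
from typing import Iterator, List, Optional


def minibatch_indices(n_samples: int, batch_size: int, shuffle: bool = True,
                      seed: Optional[int] = None) -> Iterator[List[int]]:
    """Yield one list of sample indices per mini-batch, covering one epoch.

    Every index in range(n_samples) appears in exactly one batch; the last
    batch is smaller when n_samples is not a multiple of batch_size.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    order = list(range(n_samples))
    if shuffle:
        # A fresh permutation per call; pass a different seed (or None) each epoch.
        random.Random(seed).shuffle(order)
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]
```

In a loop this would be used as `for idx in minibatch_indices(len(train_x), 32, seed=epoch):` followed by `x_batch, y_batch = train_x[idx], train_y[idx]` and the forward pass and weight update, for example inside a tf.GradientTape block.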

@thegodone
Copy link

thegodone commented Jun 23, 2021

import numpy as np
import tensorflow as tf
import haste_tf as haste

train_x = np.random.rand(500,40,20)
train_y = np.random.rand(500,40,1)

inputs = tf.keras.Input(shape=(train_x.shape[1], train_x.shape[2]))
haste1 = haste.LayerNormGRU(20, direction='unidirectional', zoneout=0.1, dropout=0.1)
fc1 = tf.keras.layers.Dense(60, activation='relu', kernel_initializer='he_uniform')
dr1 = tf.keras.layers.Dropout(0.2)
fc2 = tf.keras.layers.Dense(1)

x, state = haste1(inputs, training=True)
x = fc1(inputs)  # note: this feeds the raw inputs, so the haste output x is unused
x = dr1(x)
outputs = fc2(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

model.summary()  # summary() prints itself; wrapping it in print() only adds "None"

opt = tf.keras.optimizers.Adam(learning_rate=0.01)
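model.compile(loss='categorical_crossentropy', optimizer=opt)  # note: with one output unit this loss is (near-)constant; a regression loss such as 'mse' may be intended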

model_hist = model.fit(train_x, train_y, epochs=10, batch_size=32, verbose=1)

This code works for me.

@mahenning
Author

For me the error still occurs. Which Python/TensorFlow versions did you use for testing?
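Comparing environments is easier when the versions are collected the same way on each machine. A small stdlib-only sketch that reports the fields from the environment list in the original post; it assumes Python 3.8+ for importlib.metadata (on 3.7 the importlib_metadata backport provides the same functions), the helper name environment_report is hypothetical, and the package names are the pip distribution names used in this thread:

```python
import platform
from importlib import metadata  # Python 3.8+
from typing import Dict, Iterable


def environment_report(packages: Iterable[str]) -> Dict[str, str]:
    """Return Python/OS info plus the installed version of each distribution."""
    report = {
        "python": platform.python_version(),
        "os": platform.platform(),
    }
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in environment_report(
            ["tensorflow", "keras", "numpy", "haste_tf"]).items():
        print(f"{key}: {value}")
```

Running it with the interpreter of each venv shows at a glance whether TensorFlow, Keras and haste_tf differ between a setup where the example works and one where it fails.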

@thegodone

thegodone commented Jun 23, 2021 via email

@thegodone

Does it work for you?

@mahenning
Author

No.
Unfortunately I can't test this in another environment: when I run pip(3) install haste_tf, libhaste_tf.so is not installed, neither on a second PC (Python 3.6) nor on my PC in a fresh venv with Python 3.8. When I reinstall haste in my conda venv with Python 3.7, libhaste_tf.so is created (I also checked that it isn't deleted in between). Any idea what causes this? I understand this is unrelated to my initial post, so should I open another issue? I really don't understand what I'm doing wrong to get all these errors.
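One way to check, per interpreter, whether the shared library actually ended up in the installed package is to locate the package directory without importing it (importing may itself fail when the library is missing) and look for the file there. A minimal sketch, assuming the library is expected directly inside the installed haste_tf package directory; both helper names are hypothetical:

```python
import importlib.util
from pathlib import Path
from typing import List, Optional


def installed_package_dir(package: str) -> Optional[Path]:
    """Locate a top-level package's directory without importing it."""
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        return None  # not installed, or a plain module rather than a package
    return Path(list(spec.submodule_search_locations)[0])


def find_shared_libs(package_dir: Path, pattern: str = "libhaste_tf*.so") -> List[Path]:
    """List files matching `pattern` directly inside package_dir."""
    return sorted(package_dir.glob(pattern))


if __name__ == "__main__":
    pkg_dir = installed_package_dir("haste_tf")
    if pkg_dir is None:
        print("haste_tf is not installed for this interpreter")
    else:
        libs = find_shared_libs(pkg_dir)
        print(pkg_dir, "->", [p.name for p in libs] or "no libhaste_tf*.so found")
```

Running it from each venv (py 3.6, 3.7 and 3.8) would show whether the difference is in what pip installed rather than in the model code.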

@mahenning
Author

I tested it in a Google Colab notebook (https://colab.research.google.com/drive/1mHuT35cnW5uOgulDjHGPf88gm74O9Caz?usp=sharing) and the same error occurs there.

@max1mn

max1mn commented Jun 26, 2021

Hello everyone. Actually, the real bug here is this one: TypeError: haste_lstm() missing 1 required positional argument: 'training'

It happens because Keras wraps this call in haste_tf/lstm.py

h, c, _ = LIB.haste_lstm(

into a TFOpLambda.

LIB.haste_lstm expects a 'training' argument, but once the call is wrapped, that argument is lost: the TFOpLambda constructor explicitly sets the flag '_expects_training_arg' to False, and because of that 'training' is removed.

This is what happens in TF 2.5.0.

P.S. The obvious immediate fix is to just comment out this line in the Keras sources.
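The mechanism described above can be reproduced without TensorFlow. The sketch below is a deliberately simplified toy model, not the Keras source: ToyOpLambda and op_requiring_training are hypothetical stand-ins for TFOpLambda and LIB.haste_lstm, and the only behaviour copied is the one described in this comment, namely that a wrapper whose _expects_training_arg flag is False drops the training keyword before calling the wrapped function:

```python
from typing import Any, Callable


def op_requiring_training(inputs, kernel, training):
    """Hypothetical stand-in for an op whose `training` argument has no default."""
    return ("ok", training)


class ToyOpLambda:
    """Toy model of a layer that wraps a plain function (not Keras source)."""

    def __init__(self, function: Callable[..., Any]):
        self.function = function
        self._expects_training_arg = False  # the flag TFOpLambda sets to False

    def __call__(self, *args, **kwargs):
        if not self._expects_training_arg:
            kwargs.pop("training", None)  # the keyword is silently discarded
        return self.function(*args, **kwargs)


def try_call(fn, *args, **kwargs):
    """Return fn's result, or the TypeError message if the call fails."""
    try:
        return fn(*args, **kwargs)
    except TypeError as err:
        return f"TypeError: {err}"


if __name__ == "__main__":
    # Direct call: training reaches the function.
    print(try_call(op_requiring_training, "x", "w", training=True))
    # Wrapped call: training is dropped, so the required argument is missing.
    print(try_call(ToyOpLambda(op_requiring_training), "x", "w", training=True))
```

The direct call prints ('ok', True), while the wrapped call prints a TypeError naming the missing 'training' argument, the same shape as the final error in the original traceback.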

@mahenning
Author

Hey, thank you for figuring that out, I'll try it! But is there any chance of getting this fixed officially, without changing the Keras core file?

@mw66

mw66 commented Nov 6, 2023


@thegodone Could you add this code as an example of how to mix haste layers into a tf.keras model? I think new users would appreciate it.

Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants