Support for tf.RaggedTensor Input #185

Open
mimxrt opened this issue Feb 16, 2021 · 10 comments
Labels
enhancement New feature or request

Comments

@mimxrt

mimxrt commented Feb 16, 2021

Currently, Keras TCN does not support tf.RaggedTensor input. Supporting it would be very useful for speeding up training on variable-length time series. As I understand it, there is no way to batch sequences of different lengths except with tf.RaggedTensor*. See the following minimal example of the current state:

*EDIT: Masking is another option, but still...

import tcn
import numpy as np
import tensorflow as tf
import tensorflow.keras as K

batch_size=2

X = [
    np.array([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]),
    np.array([[1, 1], [2, 2], [3, 3]]),
    np.array([[1, 1]]),
    np.array([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]]),
]
Y = np.array([6, 1, 2, 4])

ds_raw_X = tf.data.Dataset.from_tensor_slices(tf.ragged.constant(X, inner_shape=(2,)))
ds_raw_Y = tf.data.Dataset.from_tensor_slices(Y)

ds_raw = tf.data.Dataset.zip((ds_raw_X, ds_raw_Y))

print("Raw dataset:")
for e in ds_raw.take(2):
    print(f"{(e[0].shape, e[1].shape)}")
print()

ds = ds_raw.batch(batch_size)

print("Final dataset:")
for e in ds.take(2):
    print(f"{(e[0].shape, e[1].shape)}")
print()

m_in = K.Input(shape=(None, 2), batch_size=batch_size)
m_out = tcn.TCN(
    nb_filters=32,
    return_sequences=False
)(m_in)
m_out = K.layers.Dense(1)(m_out)

model = K.Model([m_in], m_out)
print(model.summary(), end="\n\n")

model.compile(optimizer=K.optimizers.Adam(), loss=K.losses.MeanSquaredError())

model.fit(ds, epochs=1)

Output:

Raw dataset:
(TensorShape([5, 2]), TensorShape([]))
(TensorShape([3, 2]), TensorShape([]))

Final dataset:
(TensorShape([2, None, 2]), TensorShape([2]))
(TensorShape([2, None, 2]), TensorShape([2]))

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(2, None, 2)]            0         
_________________________________________________________________
tcn (TCN)                    (2, 32)                   23136     
_________________________________________________________________
dense (Dense)                (2, 1)                    33        
=================================================================
Total params: 23,169
Trainable params: 23,169
Non-trainable params: 0
_________________________________________________________________
None

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/dss/code-envs/python/tensorflow_2/lib/python3.6/site-packages/tensorflow_core/python/util/nest.py in assert_same_structure(nest1, nest2, check_types, expand_composites)
    329     _pywrap_utils.AssertSameStructure(nest1, nest2, check_types,
--> 330                                       expand_composites)
    331   except (ValueError, TypeError) as e:

ValueError: The two structures don't have the same nested structure.

First structure: type=RaggedTensorSpec str=RaggedTensorSpec(TensorShape([None, None, 2]), tf.int32, 1, tf.int64)

Second structure: type=Tensor str=Tensor("input_1:0", shape=(2, None, 2), dtype=float32)

More specifically: Substructure "type=RaggedTensorSpec str=RaggedTensorSpec(TensorShape([None, None, 2]), tf.int32, 1, tf.int64)" is a sequence, while substructure "type=Tensor str=Tensor("input_1:0", shape=(2, None, 2), dtype=float32)" is not

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-1-52273b009edd> in <module>
     43 model.compile(optimizer=K.optimizers.Adam(), loss=K.losses.MeanSquaredError())
     44 
---> 45 model.fit(ds, epochs=1)

~/dss/code-envs/python/tensorflow_2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    817         max_queue_size=max_queue_size,
    818         workers=workers,
--> 819         use_multiprocessing=use_multiprocessing)
    820 
    821   def evaluate(self,

~/dss/code-envs/python/tensorflow_2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    233           max_queue_size=max_queue_size,
    234           workers=workers,
--> 235           use_multiprocessing=use_multiprocessing)
    236 
    237       total_samples = _get_total_number_of_samples(training_data_adapter)

~/dss/code-envs/python/tensorflow_2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_training_inputs(model, x, y, batch_size, epochs, sample_weights, class_weights, steps_per_epoch, validation_split, validation_data, validation_steps, shuffle, distribution_strategy, max_queue_size, workers, use_multiprocessing)
    591         max_queue_size=max_queue_size,
    592         workers=workers,
--> 593         use_multiprocessing=use_multiprocessing)
    594     val_adapter = None
    595     if validation_data:

~/dss/code-envs/python/tensorflow_2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_inputs(model, mode, x, y, batch_size, epochs, sample_weights, class_weights, shuffle, steps, distribution_strategy, max_queue_size, workers, use_multiprocessing)
    704       max_queue_size=max_queue_size,
    705       workers=workers,
--> 706       use_multiprocessing=use_multiprocessing)
    707 
    708   return adapter

~/dss/code-envs/python/tensorflow_2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/data_adapter.py in __init__(self, x, y, sample_weights, standardize_function, **kwargs)
    700 
    701     if standardize_function is not None:
--> 702       x = standardize_function(x)
    703 
    704     # Note that the dataset instance is immutable, its fine to reusing the user

~/dss/code-envs/python/tensorflow_2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in standardize_function(dataset)
    658         model.sample_weight_mode = getattr(model, 'sample_weight_mode', None)
    659 
--> 660       standardize(dataset, extract_tensors_from_dataset=False)
    661 
    662       # Then we map using only the tensor standardization portion.

~/dss/code-envs/python/tensorflow_2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
   2381         is_dataset=is_dataset,
   2382         class_weight=class_weight,
-> 2383         batch_size=batch_size)
   2384 
   2385   def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs,

~/dss/code-envs/python/tensorflow_2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, is_dataset, class_weight, batch_size)
   2445     flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
   2446     for (a, b) in zip(flat_inputs, flat_expected_inputs):
-> 2447       nest.assert_same_structure(a, b, expand_composites=True)
   2448 
   2449     if y is not None:

~/dss/code-envs/python/tensorflow_2/lib/python3.6/site-packages/tensorflow_core/python/util/nest.py in assert_same_structure(nest1, nest2, check_types, expand_composites)
    335                   "Entire first structure:\n%s\n"
    336                   "Entire second structure:\n%s"
--> 337                   % (str(e), str1, str2))
    338 
    339 

ValueError: The two structures don't have the same nested structure.

First structure: type=RaggedTensorSpec str=RaggedTensorSpec(TensorShape([None, None, 2]), tf.int32, 1, tf.int64)

Second structure: type=Tensor str=Tensor("input_1:0", shape=(2, None, 2), dtype=float32)

More specifically: Substructure "type=RaggedTensorSpec str=RaggedTensorSpec(TensorShape([None, None, 2]), tf.int32, 1, tf.int64)" is a sequence, while substructure "type=Tensor str=Tensor("input_1:0", shape=(2, None, 2), dtype=float32)" is not
Entire first structure:
.
Entire second structure:
.

I think this is expected to fail because I did not add the ragged=True parameter. When I add it, the error is explicit (Layer tcn_1 does not support RaggedTensors as input):

m_in = K.Input(shape=(None, 2), batch_size=batch_size, ragged=True) # note the added ragged=True parameter
m_out = tcn.TCN(
    nb_filters=32,
    return_sequences=False
)(m_in)
m_out = K.layers.Dense(1)(m_out)

model = K.Model([m_in], m_out)
print(model.summary(), end="\n\n")

Output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-2-3821044ae8a6> in <module>
      3     nb_filters=32,
      4     return_sequences=False
----> 5 )(m_in)
      6 m_out = K.layers.Dense(1)(m_out)
      7 

~/dss/code-envs/python/tensorflow_2/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    740           raise ValueError('Layer %s does not support RaggedTensors as input. '
    741                            'Inputs received: %s. You can try converting your '
--> 742                            'input to an uniform tensor.' % (self.name, inputs))
    743 
    744         graph = backend.get_graph()

ValueError: Layer tcn_1 does not support RaggedTensors as input. Inputs received: tf.RaggedTensor(values=Tensor("Placeholder:0", shape=(None, 2), dtype=float32), row_splits=Tensor("Placeholder_1:0", shape=(3,), dtype=int64)). You can try converting your input to an uniform tensor.
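A possible interim workaround (not part of the report above, and untested against TCN internals): since the first failure is a structure mismatch between the ragged dataset elements and the dense Input, one option is to densify the ragged batches inside the tf.data pipeline, at the cost of zero padding flowing into the convolutions:

# Hypothetical workaround: pad each ragged batch to a dense tensor
# before it reaches the model. Padded steps are real zeros to the TCN.
ds_dense = ds.map(lambda x, y: (tf.cast(x.to_tensor(default_value=0), tf.float32), y))

m_in = K.Input(shape=(None, 2))  # plain dense input, no ragged=True
m_out = tcn.TCN(nb_filters=32, return_sequences=False)(m_in)
m_out = K.layers.Dense(1)(m_out)
model = K.Model([m_in], m_out)
model.compile(optimizer=K.optimizers.Adam(), loss=K.losses.MeanSquaredError())
model.fit(ds_dense, epochs=1)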
philipperemy added the enhancement (New feature or request) label on Feb 17, 2021
@philipperemy
Owner

@mimxrt I'm not sure exactly why it does not work. This should work in theory:

m_in = K.Input(shape=(None, 2), batch_size=batch_size, ragged=True)

Ref: tensorflow/tensorflow#27170

@mimxrt
Author

mimxrt commented Mar 10, 2021

I'm not sure why, but it doesn't work for me. I also made sure to get the latest version of TensorFlow this time! Can you try running this example?

import sys
import tcn
import numpy as np
import tensorflow as tf
import tensorflow.keras as K

batch_size=2

print(f"sys.version: {sys.version}")
print(f"tf.__version__: {tf.__version__}")

X = [
    np.array([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]),
    np.array([[1, 1], [2, 2], [3, 3]]),
    np.array([[1, 1]]),
    np.array([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]]),
]
Y = np.array([6, 1, 2, 4])

ds_raw_X = tf.data.Dataset.from_tensor_slices(tf.ragged.constant(X, inner_shape=(2,)))
ds_raw_Y = tf.data.Dataset.from_tensor_slices(Y)

ds_raw = tf.data.Dataset.zip((ds_raw_X, ds_raw_Y))

print("Raw dataset:")
for e in ds_raw.take(2):
    print(f"{(e[0].shape, e[1].shape)}")
print("...\n")

ds = ds_raw.batch(batch_size)

print("Final dataset:")
for e in ds.take(2):
    print(f"{(e[0].shape, e[1].shape)}")
print("...\n")

m_in = K.Input(shape=(None, 2), batch_size=batch_size, ragged=True) # note the added ragged=True parameter
m_out = tcn.TCN(
    nb_filters=32,
    return_sequences=False
)(m_in)
m_out = K.layers.Dense(1)(m_out)

model = K.Model([m_in], m_out)
print(model.summary(), end="\n\n")

model.compile(optimizer=K.optimizers.Adam(), loss=K.losses.MeanSquaredError())

model.fit(ds, epochs=1)

Output for me:

sys.version: 3.7.10 (default, Feb 26 2021, 13:06:18) [MSC v.1916 64 bit (AMD64)]
tf.__version__: 2.4.1
Raw dataset:
(TensorShape([5, 2]), TensorShape([]))
(TensorShape([3, 2]), TensorShape([]))
...

Final dataset:
(TensorShape([2, None, 2]), TensorShape([2]))
(TensorShape([2, None, 2]), TensorShape([2]))
...

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-2-c63b87df74e5> in <module>
     39     nb_filters=32,
     40     return_sequences=False
---> 41 )(m_in)
     42 m_out = K.layers.Dense(1)(m_out)
     43 

~\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
    950     if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
    951       return self._functional_construction_call(inputs, args, kwargs,
--> 952                                                 input_list)
    953 
    954     # Maintains info about the `Layer.call` stack.

~\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
   1089         # Check input assumptions set after layer building, e.g. input shape.
   1090         outputs = self._keras_tensor_symbolic_call(
-> 1091             inputs, input_masks, args, kwargs)
   1092 
   1093         if outputs is None:

~\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs)
    820       return nest.map_structure(keras_tensor.KerasTensor, output_signature)
    821     else:
--> 822       return self._infer_output_signature(inputs, args, kwargs, input_masks)
    823 
    824   def _infer_output_signature(self, inputs, args, kwargs, input_masks):

~\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in _infer_output_signature(self, inputs, args, kwargs, input_masks)
    861           # TODO(kaftan): do we maybe_build here, or have we already done it?
    862           self._maybe_build(inputs)
--> 863           outputs = call_fn(inputs, *args, **kwargs)
    864 
    865         self._handle_activity_regularization(inputs, outputs)

~\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\autograph\impl\api.py in wrapper(*args, **kwargs)
    668       except Exception as e:  # pylint:disable=broad-except
    669         if hasattr(e, 'ag_error_metadata'):
--> 670           raise e.ag_error_metadata.to_exception(e)
    671         else:
    672           raise

TypeError: in user code:

    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tcn\tcn.py:319 call  *
        x, skip_out = layer(K.cast(x, 'float32'), training=training)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tcn\tcn.py:155 call  *
        x = layer(x, training=training) if training_flag else layer(x)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\base_layer.py:1012 __call__  **
        outputs = call_fn(inputs, *args, **kwargs)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\layers\convolutional.py:246 call
        inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs))
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\util\dispatch.py:201 wrapper
        return target(*args, **kwargs)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\ops\array_ops.py:3422 pad
        result = gen_array_ops.pad(tensor, paddings, name=name)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\ops\gen_array_ops.py:6484 pad
        "Pad", input=input, paddings=paddings, name=name)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\op_def_library.py:525 _apply_op_helper
        raise err
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\op_def_library.py:522 _apply_op_helper
        preferred_dtype=default_dtype)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\profiler\trace.py:163 wrapped
        return func(*args, **kwargs)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\ops.py:1540 convert_to_tensor
        ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\constant_op.py:339 _constant_tensor_conversion_function
        return constant(v, dtype=dtype, name=name)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\constant_op.py:265 constant
        allow_broadcast=True)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\constant_op.py:283 _constant_impl
        allow_broadcast=allow_broadcast))
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\tensor_util.py:553 make_tensor_proto
        "supported type." % (type(values), values))

    TypeError: Failed to convert object of type <class 'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor'> to Tensor. Contents: tf.RaggedTensor(values=Tensor("Placeholder:0", shape=(None, 2), dtype=float32), row_splits=Tensor("Placeholder_1:0", shape=(3,), dtype=int64)). Consider casting elements to a supported type.
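Reading the traceback, the failure is inside Conv1D's causal padding: array_ops.pad receives the RaggedTensor and tries to convert it to a dense Tensor. A minimal sketch of just that step, which is expected to fail the same way:

import tensorflow as tf

# shape (2, None, 2): two sequences of different lengths
rt = tf.ragged.constant([[[1., 1.], [2., 2.]], [[1., 1.]]], ragged_rank=1)
conv = tf.keras.layers.Conv1D(filters=4, kernel_size=2, padding='causal')
conv(rt)  # fails in tf.pad: a RaggedTensor cannot be converted to a dense Tensor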

@philipperemy
Owner

@mimxrt yes, it doesn't work for me either. But it seems to be a Keras issue here.

@mimxrt
Author

mimxrt commented Mar 11, 2021

Thanks for testing! I created another script that shows it works when simply using tf.keras.layers.LSTM instead of TCN (only one line of difference):

import sys
import tcn
import numpy as np
import tensorflow as tf
import tensorflow.keras as K

batch_size=2

print(f"sys.version: {sys.version}")
print(f"tf.__version__: {tf.__version__}")

X = [
    np.array([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]),
    np.array([[1, 1], [2, 2], [3, 3]]),
    np.array([[1, 1]]),
    np.array([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6]]),
]
Y = np.array([6, 1, 2, 4])

ds_raw_X = tf.data.Dataset.from_tensors(tf.ragged.constant(X, inner_shape=(2,)))
ds_raw_Y = tf.data.Dataset.from_tensors(tf.constant(Y))

ds_raw = tf.data.Dataset.zip((ds_raw_X, ds_raw_Y))
ds = ds_raw
ds = ds.unbatch().batch(batch_size)

print("\n\n===========================")
print("===== Works with LSTM =====")
print("===========================\n")

m_in = K.Input(shape=[None, 2], batch_size=batch_size, ragged=True)
m_out = K.layers.LSTM(32)(m_in)
m_out = K.layers.Dense(1)(m_out)
model = K.Model([m_in], m_out)
print(model.summary(), end="\n\n")
model.compile(optimizer=K.optimizers.Adam(), loss=K.losses.MeanSquaredError())
model.fit(ds, epochs=5)

print("\n\n===========================")
print("===== Fails with TCN ======")
print("===========================\n")

m_in = K.Input(shape=[None, 2], batch_size=batch_size, ragged=True)
m_out = tcn.TCN(32)(m_in)
m_out = K.layers.Dense(1)(m_out)
model = K.Model([m_in], m_out)
print(model.summary(), end="\n\n")
model.compile(optimizer=K.optimizers.Adam(), loss=K.losses.MeanSquaredError())
model.fit(ds, epochs=5)

Output:

sys.version: 3.7.10 (default, Feb 26 2021, 13:06:18) [MSC v.1916 64 bit (AMD64)]
tf.__version__: 2.4.1


===========================
===== Works with LSTM =====
===========================

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(2, None, 2)]            0         
_________________________________________________________________
lstm (LSTM)                  (2, 32)                   4480      
_________________________________________________________________
dense (Dense)                (2, 1)                    33        
=================================================================
Total params: 4,513
Trainable params: 4,513
Non-trainable params: 0
_________________________________________________________________
None

Epoch 1/5
2/2 [==============================] - 1s 5ms/step - loss: 15.8364
Epoch 2/5
2/2 [==============================] - 0s 4ms/step - loss: 15.1551
Epoch 3/5
2/2 [==============================] - 0s 5ms/step - loss: 14.4901
Epoch 4/5
2/2 [==============================] - 0s 5ms/step - loss: 13.8392
Epoch 5/5
2/2 [==============================] - 0s 5ms/step - loss: 13.2030


===========================
===== Fails with TCN ======
===========================

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-aef4100541a3> in <module>
     42 
     43 m_in = K.Input(shape=[None, 2], batch_size=batch_size, ragged=True)
---> 44 m_out = tcn.TCN(32)(m_in)
     45 m_out = K.layers.Dense(1)(m_out)
     46 model = K.Model([m_in], m_out)

~\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
    950     if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
    951       return self._functional_construction_call(inputs, args, kwargs,
--> 952                                                 input_list)
    953 
    954     # Maintains info about the `Layer.call` stack.

~\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
   1089         # Check input assumptions set after layer building, e.g. input shape.
   1090         outputs = self._keras_tensor_symbolic_call(
-> 1091             inputs, input_masks, args, kwargs)
   1092 
   1093         if outputs is None:

~\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs)
    820       return nest.map_structure(keras_tensor.KerasTensor, output_signature)
    821     else:
--> 822       return self._infer_output_signature(inputs, args, kwargs, input_masks)
    823 
    824   def _infer_output_signature(self, inputs, args, kwargs, input_masks):

~\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in _infer_output_signature(self, inputs, args, kwargs, input_masks)
    861           # TODO(kaftan): do we maybe_build here, or have we already done it?
    862           self._maybe_build(inputs)
--> 863           outputs = call_fn(inputs, *args, **kwargs)
    864 
    865         self._handle_activity_regularization(inputs, outputs)

~\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\autograph\impl\api.py in wrapper(*args, **kwargs)
    668       except Exception as e:  # pylint:disable=broad-except
    669         if hasattr(e, 'ag_error_metadata'):
--> 670           raise e.ag_error_metadata.to_exception(e)
    671         else:
    672           raise

TypeError: in user code:

    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tcn\tcn.py:319 call  *
        x, skip_out = layer(K.cast(x, 'float32'), training=training)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tcn\tcn.py:155 call  *
        x = layer(x, training=training) if training_flag else layer(x)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\engine\base_layer.py:1012 __call__  **
        outputs = call_fn(inputs, *args, **kwargs)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\keras\layers\convolutional.py:246 call
        inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs))
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\util\dispatch.py:201 wrapper
        return target(*args, **kwargs)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\ops\array_ops.py:3422 pad
        result = gen_array_ops.pad(tensor, paddings, name=name)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\ops\gen_array_ops.py:6484 pad
        "Pad", input=input, paddings=paddings, name=name)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\op_def_library.py:525 _apply_op_helper
        raise err
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\op_def_library.py:522 _apply_op_helper
        preferred_dtype=default_dtype)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\profiler\trace.py:163 wrapped
        return func(*args, **kwargs)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\ops.py:1540 convert_to_tensor
        ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\constant_op.py:339 _constant_tensor_conversion_function
        return constant(v, dtype=dtype, name=name)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\constant_op.py:265 constant
        allow_broadcast=True)
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\constant_op.py:283 _constant_impl
        allow_broadcast=allow_broadcast))
    C:\Users\user\Anaconda3\envs\tf2\lib\site-packages\tensorflow\python\framework\tensor_util.py:553 make_tensor_proto
        "supported type." % (type(values), values))

    TypeError: Failed to convert object of type <class 'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor'> to Tensor. Contents: tf.RaggedTensor(values=Tensor("Placeholder:0", shape=(None, 2), dtype=float32), row_splits=Tensor("Placeholder_1:0", shape=(3,), dtype=int64)). Consider casting elements to a supported type.

Is there anything that I need to configure properly to make it work with TCNs?

@philipperemy
Owner

I'd say it's because we have a custom function to build() the network. Adding support for RaggedTensor would mean updating that function. That's my gut feeling.
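A rough sketch of what such support could look like from the outside, under the assumption that densifying first is acceptable; ragged_tcn is a hypothetical helper, not part of keras-tcn. The key detail is selecting each sequence's last valid step, since the stock x[:, -1, :] would read padding for short sequences:

import tensorflow as tf
from tcn import TCN

def ragged_tcn(inputs, training=None, **tcn_kwargs):
    # Densify the ragged batch and remember the true sequence lengths.
    lengths = inputs.row_lengths()                   # (batch,), int64
    dense = tf.cast(inputs.to_tensor(), tf.float32)  # zero-padded (batch, max_len, feat)
    seq = TCN(return_sequences=True, **tcn_kwargs)(dense, training=training)
    # Gather the output at each sequence's last valid timestep.
    batch = tf.range(tf.shape(seq, out_type=tf.int64)[0])
    return tf.gather_nd(seq, tf.stack([batch, lengths - 1], axis=1))

# Eager usage:
rt = tf.ragged.constant([[[1., 1.], [2., 2.]], [[3., 3.]]], ragged_rank=1)
print(ragged_tcn(rt, nb_filters=8).shape)  # (2, 8)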

@mimxrt
Author

mimxrt commented Mar 11, 2021

Maybe it also has something to do with this issue: tensorflow/tensorflow#42417?

However, I also don't really know what the technical difficulty is: in my understanding, everything that supports tf.keras.layers.Masking should also be able to support tf.RaggedTensor. In the worst case you would have to manually pad the ragged tensor data to the maximum length and mask the padded values (granted, you would have to know or compute the lengths in advance).
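A minimal sketch of that manual fallback, reusing X and Y from the scripts above. Whether TCN actually consumes the mask is exactly the open question here; if it rejects the mask, dropping the Masking layer leaves plain zero padding, which trains but does not exclude the padded steps:

import tensorflow as tf
import tensorflow.keras as K
import tcn

max_len = max(len(x) for x in X)  # lengths computed in advance
X_pad = K.preprocessing.sequence.pad_sequences(
    X, maxlen=max_len, dtype='float32', padding='post')  # (4, max_len, 2)

m_in = K.Input(shape=(max_len, 2))
m_msk = K.layers.Masking(mask_value=0.0)(m_in)  # flags all-zero padded steps
m_out = tcn.TCN(nb_filters=32)(m_msk)
m_out = K.layers.Dense(1)(m_out)
model = K.Model([m_in], m_out)
model.compile(optimizer=K.optimizers.Adam(), loss=K.losses.MeanSquaredError())
model.fit(X_pad, Y, batch_size=2, epochs=1)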

@philipperemy
Owner

Yes, I think you are right here. It's related. I'll see what I can do ;)

@maximgeller

maximgeller commented Apr 21, 2021

However, I also don't really know what the technical difficulty is: in my understanding, everything that supports tf.keras.layers.Masking should also be able to support tf.RaggedTensor. In the worst case you would have to manually pad the ragged tensor data to the maximum length and mask the padded values (granted, you would have to know or compute the lengths in advance).

@mimxrt
Unfortunately, I do not think this is true: in the issue you mentioned, it appears that convolutions cannot be applied to tensors of variable length. I've experienced this with ConvLSTM2D as well: tensorflow/tensorflow#48678

@mimxrt
Author

mimxrt commented Apr 22, 2021

@maximgeller I think you might have misunderstood my statement: I meant that if masking is supported (i.e., by everything that supports tf.keras.layers.Masking), then it should technically be possible to support ragged tensors as well. That still means someone has to implement it.
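To illustrate the equivalence being claimed: a ragged batch carries the same information as a padded batch plus a boolean mask.

import tensorflow as tf

rt = tf.ragged.constant([[[1., 1.], [2., 2.]], [[3., 3.]]], ragged_rank=1)
dense = rt.to_tensor()                     # (2, 2, 2), zero-padded
mask = tf.sequence_mask(rt.row_lengths())  # [[True, True], [True, False]]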

@philipperemy
Owner

philipperemy commented May 11, 2022
