
Bug when stacking LSTM #8255

Open
kzay opened this issue Apr 22, 2024 · 4 comments
Assignees
gaikwadrahul8
Labels
comp:layers type:bug Something isn't working

Comments


kzay commented Apr 22, 2024

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow.js):
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows 11
  • TensorFlow.js installed from (npm or script link): NPM
  • TensorFlow.js version (use command below): 4.18.0

Describe the current behavior

When creating a model that stacks two LSTM layers, I ran into an error.

Describe the expected behavior

  • Input data can be dummy values to reproduce the problem (see the training sketch after the code below).
  • When the second LSTM is removed, everything works.

Error during model training: Argument tensors passed to stack must be a `Tensor[]` or `TensorLike[]`

__________________________________________________________________________________________
Layer (type)                Input Shape               Output shape              Param #
==========================================================================================
lstm_1 (LSTM)               [[null,1,8]]              [null,1,100]              43600
__________________________________________________________________________________________
dropout_1 (Dropout)         [[null,1,100]]            [null,1,100]              0
__________________________________________________________________________________________
bidirectional_lstm (Bidirec [[null,1,100]]            [null,100]                160800
__________________________________________________________________________________________
dropout_2 (Dropout)         [[null,100]]              [null,100]                0
__________________________________________________________________________________________
dense_1 (Dense)             [[null,100]]              [null,1]                  101
==========================================================================================
Total params: 204501
Trainable params: 204501
Non-trainable params: 0
__________________________________________________________________________________________

Standalone code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/CodePen/any notebook.

const model = tf.sequential();
model.add(
    tf.layers.lstm({
        units: 100,
        inputShape: [1, 8], // Flexible time steps, defined number of features
        returnSequences: true,
        kernelInitializer: "glorotUniform", // For the input kernel
        recurrentInitializer: "orthogonal", // Especially good for RNNs
        name: "lstm_1",
    })
);

model.add(tf.layers.dropout({ rate: 0.3, name: "dropout_1" }));
model.add(
    tf.layers.bidirectional({
        layer: tf.layers.lstm({
            units: 100,
            returnSequences: false,
            kernelInitializer: "glorotUniform",
            recurrentInitializer: "orthogonal",
            name: "lstm_2",
        }),
        name: "bidirectional_lstm",
        mergeMode: "ave",
    })
);
model.add(tf.layers.dropout({ rate: 0.3, name: "dropout_2" }));

model.add(
    tf.layers.dense({
        units: 1,
        activation: "sigmoid",
        kernelInitializer: "glorotUniform",
        name: "dense_1",
    })
);
model.compile({
    optimizer: "adam",
    loss: "binaryCrossentropy",
    metrics: ["accuracy"],
});
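
Training can then be kicked off with dummy data to trigger the error (a sketch: the sample count, batch size, and label shape below are assumptions; only the [1, 8] input shape comes from the model above):

// Dummy inputs matching inputShape: [1, 8]; 32 samples is an arbitrary choice.
const xs = tf.randomNormal([32, 1, 8]);
const ys = tf.randomUniform([32, 1]).round();

model
    .fit(xs, ys, { epochs: 1, batchSize: 8 })
    .then(() => console.log("Training finished"))
    .catch((err) => console.error("Error during model training:", err.message));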

Other info / logs Include any logs or source code that would be helpful to
diagnose the problem. If including tracebacks, please include the full
traceback. Large logs and files should be attached.

kzay added the type:bug label Apr 22, 2024
gaikwadrahul8 self-assigned this Apr 22, 2024
gaikwadrahul8 (Contributor) commented

Hi, @kzay

Thank you for reporting this issue. I attempted to reproduce it with the following code snippet on my Mac M1 system and did not encounter the "Argument tensors passed to stack must be a `Tensor[]` or `TensorLike[]`" error. Since you're on a Windows 11 system, the issue might be specific to that platform. Could you confirm whether there are any relevant differences in your environment, or guide me on further steps to replicate the same behavior on my end?

import * as tf from "@tensorflow/tfjs";

const model = tf.sequential();
model.add(
  tf.layers.lstm({
    units: 100,
    inputShape: [1, 8], // Flexible time steps, defined number of features
    returnSequences: true,
    kernelInitializer: "glorotUniform", // For the input kernel
    recurrentInitializer: "orthogonal", // Especially good for RNNs
    name: "lstm_1",
  })
);

// Stacked LSTM layer
model.add(
  tf.layers.lstm({
    units: 100,
    returnSequences: true, // Maintain sequence for next LSTM
    kernelInitializer: "glorotUniform", // For the input kernel
    recurrentInitializer: "orthogonal", // Especially good for RNNs
    name: "lstm_2",
  })
);

model.add(tf.layers.dropout({ rate: 0.3, name: "dropout_1" }));
model.add(
  tf.layers.bidirectional({
    layer: tf.layers.lstm({
      units: 100,
      returnSequences: false,
      kernelInitializer: "glorotUniform",
      recurrentInitializer: "orthogonal",
      name: "lstm_3",
    }),
    name: "bidirectional_lstm",
    mergeMode: "ave",
  })
);
model.add(tf.layers.dropout({ rate: 0.3, name: "dropout_2" }));

model.add(
  tf.layers.dense({
    units: 1,
    activation: "sigmoid",
    kernelInitializer: "glorotUniform",
    name: "dense_1",
  })
);
model.compile({
  optimizer: "adam",
  loss: "binaryCrossentropy",
  metrics: ["accuracy"],
});
console.log(model.summary())

Output of the above code snippet:

(base) gaikwadrahul-macbookpro:test-8255 gaikwadrahul$ node index.js

============================
Hi, looks like you are running TensorFlow.js in Node.js. To speed things up dramatically, install our node backend, visit https://github.com/tensorflow/tfjs-node for more details. 
============================
Orthogonal initializer is being called on a matrix with more than 2000 (40000) elements: Slowness may result.
Orthogonal initializer is being called on a matrix with more than 2000 (40000) elements: Slowness may result.
Orthogonal initializer is being called on a matrix with more than 2000 (40000) elements: Slowness may result.
Orthogonal initializer is being called on a matrix with more than 2000 (40000) elements: Slowness may result.
__________________________________________________________________________________________
Layer (type)                Input Shape               Output shape              Param #   
==========================================================================================
lstm_1 (LSTM)               [[null,1,8]]              [null,1,100]              43600     
__________________________________________________________________________________________
lstm_2 (LSTM)               [[null,1,100]]            [null,1,100]              80400     
__________________________________________________________________________________________
dropout_1 (Dropout)         [[null,1,100]]            [null,1,100]              0         
__________________________________________________________________________________________
bidirectional_lstm (Bidirec [[null,1,100]]            [null,100]                160800    
__________________________________________________________________________________________
dropout_2 (Dropout)         [[null,100]]              [null,100]                0         
__________________________________________________________________________________________
dense_1 (Dense)             [[null,100]]              [null,1]                  101       
==========================================================================================
Total params: 284901
Trainable params: 284901
Non-trainable params: 0
__________________________________________________________________________________________

Thank you for your cooperation and patience.


kzay commented Apr 23, 2024

Thanks @gaikwadrahul8. The only major change I made was fixing the lib folder of the NPM package, as mentioned in another ticket, to work around a binding issue by copying the *.dll from napi-v9 to napi-v8.

[screenshot: tfjs-node lib folder after copying the napi-v9 *.dll into napi-v8]

Error without the DLL:

Error: The specified module could not be found. ***\node_modules\@tensorflow\tfjs-node\lib\napi-v8\tfjs_binding.node

I also downgraded to Node 19 and no longer need to apply this fix.


kzay commented Apr 27, 2024

Started over from a fresh, simple install.

Scenario 1 : Windows 11, NodeJS 19.9.0, Python 3.9.13
Scenario 2 : Docker Image (tensorflow/tensorflow:latest), NodeJS latest

Same result in both scenarios: with only one LSTM it works well, but I get the error when adding a second LSTM layer.

With a single LSTM:

const model = tf.sequential();
model.add(
    tf.layers.lstm({
        units: 50,
        inputShape: [1, numFeatures],
        returnSequences: false,
    })
);

model.add(
    tf.layers.dense({
        units: 1,
        activation: "sigmoid",
    })
);
model.compile({
    optimizer: "adam",
    loss: "binaryCrossentropy",
    metrics: ["accuracy"],
});
Epoch 1 / 10
eta=0.0 =====================================================================================================================================================>
76ms 953us/step - acc=1.00 loss=0.0688 val_acc=1.00 val_loss=0.0786
Epoch 2 / 10
eta=0.0 =====================================================================================================================================================>
112ms 1405us/step - acc=1.00 loss=0.0642 val_acc=1.00 val_loss=0.0739
Epoch 3 / 10
eta=0.0 =====================================================================================================================================================>
109ms 1367us/step - acc=1.00 loss=0.0599 val_acc=1.00 val_loss=0.0697
Epoch 4 / 10
eta=0.0 =====================================================================================================================================================>
85ms 1062us/step - acc=1.00 loss=0.0561 val_acc=1.00 val_loss=0.0656
Epoch 5 / 10
eta=0.0 =====================================================================================================================================================>
75ms 934us/step - acc=1.00 loss=0.0523 val_acc=1.00 val_loss=0.0619
Epoch 6 / 10
eta=0.0 =====================================================================================================================================================>
69ms 861us/step - acc=1.00 loss=0.0490 val_acc=1.00 val_loss=0.0586
Epoch 7 / 10
eta=0.0 =====================================================================================================================================================>
76ms 945us/step - acc=1.00 loss=0.0460 val_acc=1.00 val_loss=0.0556
Epoch 8 / 10
eta=0.0 =====================================================================================================================================================>
71ms 889us/step - acc=1.00 loss=0.0435 val_acc=1.00 val_loss=0.0527
Epoch 9 / 10
eta=0.0 =====================================================================================================================================================>
68ms 856us/step - acc=1.00 loss=0.0409 val_acc=1.00 val_loss=0.0501
Epoch 10 / 10
eta=0.0 =====================================================================================================================================================>
70ms 873us/step - acc=1.00 loss=0.0386 val_acc=1.00 val_loss=0.0477
With the second LSTM added:

const model = tf.sequential();
model.add(
    tf.layers.lstm({
        units: 50,
        inputShape: [1, numFeatures],
        returnSequences: true,
    })
);
model.add(
    tf.layers.lstm({
        units: 20,
        returnSequences: false,
    })
);
model.add(
    tf.layers.dense({
        units: 1,
        activation: "sigmoid",
    })
);
model.compile({
    optimizer: "adam",
    loss: "binaryCrossentropy",
    metrics: ["accuracy"],
});
Epoch 1 / 10
Failed to train the model: Error: Argument tensors passed to stack must be a `Tensor[]` or `TensorLike[]`


kzay commented May 5, 2024

Also happening with GRU

const gruModel = tf.sequential();

// First Bidirectional GRU layer
gruModel.add(
    tf.layers.bidirectional({
        layer: tf.layers.gru({
            units: 64,
            inputShape: [this.windowSize, this.numFeatures],
            returnSequences: true, // Keep true to allow stacking of GRU or other RNN layers
            kernelRegularizer: tf.regularizers.l1({ l1: 0.01 }),
        }),
        inputShape: [this.windowSize, this.numFeatures],
        mergeMode: "concat",
    })
);
gruModel.add(tf.layers.dropout({ rate: 0.2 })); // Dropout for regularization

// Second GRU layer
gruModel.add(
    tf.layers.gru({
        units: 32,
        returnSequences: false, // Set false to connect to a Dense output layer
        kernelRegularizer: tf.regularizers.l1({ l1: 0.01 }),
    })
);
gruModel.add(tf.layers.batchNormalization()); // Batch normalization for stability

// Output layer
gruModel.add(tf.layers.dense({ units: 1, activation: "linear" }));

// Compile the model with an advanced optimizer
gruModel.compile({
    optimizer: tf.train.adam(0.001), // Consider using a learning rate scheduler
    loss: "meanSquaredError",
    metrics: ["mse"],
});
return gruModel;
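
For reference, a minimal sketch of exercising this model with dummy data (the windowSize and numFeatures values are placeholder assumptions, and buildGruModel is a hypothetical wrapper around the snippet above):

// Placeholder values standing in for this.windowSize and this.numFeatures.
const windowSize = 1;
const numFeatures = 8;

const gruModel = buildGruModel(); // hypothetical wrapper around the snippet above
const xs = tf.randomNormal([16, windowSize, numFeatures]);
const ys = tf.randomNormal([16, 1]);

gruModel
    .fit(xs, ys, { epochs: 1 })
    .catch((err) => console.error("Failed to train the model:", err.message));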
