Model ensembling with Keras Functional API #75

Open
kylec123 opened this issue Dec 17, 2021 · 6 comments

@kylec123

I'm having some issues trying to ensemble a neural network with a random forest.

The example I am following is very much like this one, but with only one NN and one RF.
https://www.tensorflow.org/decision_forests/tutorials/model_composition_colab

I have some common preprocessing layers for each model and am able to successfully train the NN component and the RF component.

My problem is a multi-class classification problem. When I call model.predict(X) on the RF model, I get, for each example, a distribution over num_classes values, the same as what my NN model returns.

The problem is that when I call the model directly through the functional API, e.g. model(input_tensors), the output has shape (batch, 1) instead of (batch, num_classes).

I want to piece the two models together like shown in the link above, but I cannot call tf.stack on tensors that are not the same size.

How do I get the call method of the RF model to return class distribution and not just a single value?
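The mismatch can be reproduced without TF-DF. The NumPy sketch below uses made-up uniform probabilities as stand-ins for the two model outputs (the arrays and the `stack_and_average` helper are illustrative, not taken from the tutorial); `np.stack` enforces the same equal-shape rule that `tf.stack` does.

```python
import numpy as np

batch, num_classes = 4, 8

# Illustrative stand-ins: the NN returns one probability per class,
# while the directly called RF returned a single column.
nn_probs = np.full((batch, num_classes), 1.0 / num_classes)
rf_output = np.zeros((batch, 1))


def stack_and_average(outputs):
    """Average several (batch, width) outputs; every width must match."""
    return np.mean(np.stack(outputs, axis=0), axis=0)


try:
    stack_and_average([nn_probs, rf_output])
    mismatch_raised = False
except ValueError:
    # np.stack, like tf.stack, requires all inputs to have the same shape.
    mismatch_raised = True

# Once the RF output also has num_classes columns, the average works.
rf_probs = np.full((batch, num_classes), 1.0 / num_classes)
ensemble = stack_and_average([nn_probs, rf_probs])
print(mismatch_raised, ensemble.shape)
```

Running it prints `True (4, 8)`: stacking fails while the widths differ and succeeds once both outputs have `num_classes` columns.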

@achoum
Collaborator

achoum commented Dec 18, 2021

Hi,

By default, a classification model outputs the probability of the positive class for binary classification, and the probability of each individual class for multi-class classification.

If instead the model is created with the advanced argument predict_single_probability_for_binary_classification=False, the probabilities of the individual classes are returned in all cases.
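For binary classification the two settings carry the same information: the single column is the positive-class probability, and the other class's probability is its complement. A NumPy sketch of that relationship, assuming the two-column output is ordered [class 0, class 1] (an assumption, not verified against TF-DF; `single_to_two_column` is a hypothetical helper):

```python
import numpy as np


def single_to_two_column(p_positive):
    """Expand (batch, 1) positive-class probabilities into (batch, 2).

    Assumption: column 0 is the negative class, column 1 the positive class.
    """
    p_positive = np.asarray(p_positive, dtype=float).reshape(-1, 1)
    return np.concatenate([1.0 - p_positive, p_positive], axis=1)


probs = single_to_two_column([[0.9], [0.25]])
print(probs)
```

This only matters for two classes; for multi-class problems the model already returns one column per class under either setting.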

The model's output is (or at least should be) the same with predict and call. If you observe a difference, could you share a small reproducible example?

Here is an illustration of the possible configurations:

import pandas as pd
import tensorflow_decision_forests as tfdf

# A toy binary classification dataset.
binary_classification_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
    pd.DataFrame(
        {"feature":[0,1,2,3]*5,
         "label":[0,1,0,1]*5}),
        label="label")

# A toy multi-class classification dataset.
multi_class_classification_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
    pd.DataFrame(
        {"feature":[0,1,2,3]*5,
         "label":[0,1,2,3]*5}),
        label="label")

def first_dataset_batch_to_tensor(dataset):
  return next(dataset.as_numpy_iterator())[0]

print("Output shapes:")

print("\tpredict_single_probability_for_binary_classification=True (default)")

model = tfdf.keras.GradientBoostedTreesModel(verbose=0)
model.fit(binary_classification_dataset)
print("\t\tPredict; Binary classification:", model.predict(binary_classification_dataset).shape)
print("\t\tCall; Binary classification:", model(first_dataset_batch_to_tensor(binary_classification_dataset)).shape)


model = tfdf.keras.GradientBoostedTreesModel(verbose=0)
model.fit(multi_class_classification_dataset)
print("\t\tPredict; Multi-class classification:", model.predict(multi_class_classification_dataset).shape)
print("\t\tCall; Multi-class classification:", model(first_dataset_batch_to_tensor(multi_class_classification_dataset)).shape)

print("\tpredict_single_probability_for_binary_classification=False")

adv_args = tfdf.keras.AdvancedArguments(
    predict_single_probability_for_binary_classification=False
)

model = tfdf.keras.GradientBoostedTreesModel(verbose=0, advanced_arguments=adv_args)
model.fit(binary_classification_dataset)
print("\t\tPredict; Binary classification:", model.predict(binary_classification_dataset).shape)
print("\t\tCall; Binary classification; default:", model(first_dataset_batch_to_tensor(binary_classification_dataset)).shape)

model = tfdf.keras.GradientBoostedTreesModel(verbose=0, advanced_arguments=adv_args)
model.fit(multi_class_classification_dataset)
print("\t\tPredict; Multi-class classification:", model.predict(multi_class_classification_dataset).shape)
print("\t\tCall; Multi-class classification; default:", model(first_dataset_batch_to_tensor(multi_class_classification_dataset)).shape)

will output:

Output shapes:
	predict_single_probability_for_binary_classification=True (default)
		Predict; Binary classification: (20, 1)
		Call; Binary classification: (20, 1)
		Predict; Multi-class classification: (20, 4)
		Call; Multi-class classification: (20, 4)
	predict_single_probability_for_binary_classification=False
		Predict; Binary classification: (20, 2)
		Call; Binary classification; default: (20, 2)
		Predict; Multi-class classification: (20, 4)
		Call; Multi-class classification; default: (20, 4)
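The shapes above reduce to a simple rule for a trained classifier. A sketch encoding that rule (the `expected_output_width` helper is hypothetical and only restates the table above, not the TF-DF implementation):

```python
def expected_output_width(num_classes, single_probability_for_binary=True):
    """Number of output columns of a trained classifier, per the table above."""
    if num_classes == 2 and single_probability_for_binary:
        return 1  # Only the positive-class probability.
    return num_classes  # One probability per class.


for num_classes in (2, 4):
    for flag in (True, False):
        print(num_classes, flag, expected_output_width(num_classes, flag))
```

For the 8-class problem in the question, this rule predicts 8 columns under either setting, so the `(batch, 1)` output described there likely comes from something else, such as a call made before the model was trained.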

@kylec123
Author

Thanks for the help. I see that it's possible to get multi-class probability outputs, but I'm not yet sure how to modify this example to make it work.

I am working from this tutorial (modified to be a multi-class problem):
https://www.tensorflow.org/decision_forests/tutorials/model_composition_colab

I assumed I could follow exactly the same general program flow as this example, but I may be wrong.


import tensorflow_decision_forests as tfdf
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import math
import matplotlib.pyplot as plt
try:
  from wurlitzer import sys_pipes
except ImportError:
  from colabtools.googlelog import CaptureLog as sys_pipes
from IPython.core.magic import register_line_magic
from IPython.display import Javascript

def make_dataset(num_examples, num_features, num_classes, seed=1234):
  np.random.seed(seed)
  features = np.random.uniform(-1, 1, size=(num_examples, num_features))
  noise = np.random.uniform(size=(num_examples))

  labels = np.random.randint(0,num_classes,size=(num_examples))
  return features, labels.astype(int)

def make_tf_dataset(batch_size=64, **args):
  features, labels = make_dataset(**args)
  return tf.data.Dataset.from_tensor_slices(
      (features, labels)).batch(batch_size)


num_features = 10
num_classes = 8

train_dataset = make_tf_dataset(
    num_examples=2500, num_features=num_features, num_classes=num_classes, batch_size=64, seed=1234)
test_dataset = make_tf_dataset(
    num_examples=10000, num_features=num_features, num_classes=num_classes, batch_size=64, seed=5678)

# Input features.
raw_features = tf.keras.layers.Input(shape=(num_features,))

# Stage 1
# =======

# Common learnable pre-processing
preprocessor = tf.keras.layers.Dense(10, activation=tf.nn.relu6)
preprocess_features = preprocessor(raw_features)

# Stage 2
# =======

# Model #1: NN
m1_z1 = tf.keras.layers.Dense(5, activation=tf.nn.relu6)(preprocess_features)
m1_pred = tf.keras.layers.Dense(num_classes, activation=tf.nn.sigmoid)(m1_z1)

# Model #2: NN
m2_z1 = tf.keras.layers.Dense(5, activation=tf.nn.relu6)(preprocess_features)
m2_pred = tf.keras.layers.Dense(num_classes, activation=tf.nn.sigmoid)(m2_z1)


def seed_advanced_argument(seed):
  """Create a seed argument for a TF-DF model.

  TODO(gbm): Surface the "seed" argument to the model constructor directly.
  """
  return tfdf.keras.AdvancedArguments(
      yggdrasil_training_config=tfdf.keras.core.YggdrasilTrainingConfig(
          random_seed=seed))


# Model #3: DF
model_3 = tfdf.keras.RandomForestModel(
    num_trees=1000, advanced_arguments=seed_advanced_argument(1234))
m3_pred = model_3(preprocess_features)

# Model #4: DF
model_4 = tfdf.keras.RandomForestModel(
    num_trees=1000,
    #split_axis="SPARSE_OBLIQUE", # Uncomment this line to increase the quality of this model
    advanced_arguments=seed_advanced_argument(4567))
m4_pred = model_4(preprocess_features)

# Since TF-DF uses deterministic learning algorithms, you should set the models'
# training seeds to different values; otherwise both
# `tfdf.keras.RandomForestModel` instances will be exactly the same.

# Stage 3
# =======
mean_nn_only = tf.reduce_mean(tf.stack([m1_pred, m2_pred], axis=0), axis=0)
mean_nn_and_df = tf.reduce_mean(
    tf.stack([m1_pred, m2_pred, m3_pred, m4_pred], axis=0), axis=0)
# Keras Models
# ============
ensemble_nn_only = tf.keras.models.Model(raw_features, mean_nn_only)
ensemble_nn_and_df = tf.keras.models.Model(raw_features, mean_nn_and_df)

The code above is enough to reproduce the issue I initially faced. I only changed the dataset creation to produce a multi-class problem and set the number of output units of the two NN models to num_classes.

Running it produces the following warnings and error:

Warning:  The model was called directly (i.e. using `model(data)` instead of using `model.predict(data)`) before being trained. The model will only return zeros until trained. The output shape might change after training Tensor("inputs:0", shape=(None, 10), dtype=float32)
Use /tmp/tmpmfxmf9f0 as temporary training directory
Warning:  The model was called directly (i.e. using `model(data)` instead of using `model.predict(data)`) before being trained. The model will only return zeros until trained. The output shape might change after training Tensor("inputs:0", shape=(None, 10), dtype=float32)

ValueError                                Traceback (most recent call last)
<ipython-input-4-6f884380391a> in <module>()
     87 mean_nn_only = tf.reduce_mean(tf.stack([m1_pred, m2_pred], axis=0), axis=0)
     88 mean_nn_and_df = tf.reduce_mean(
---> 89     tf.stack([m1_pred, m2_pred, m3_pred, m4_pred], axis=0), axis=0)
     90 
     91 # Keras Models

2 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

ValueError: Exception encountered when calling layer "tf.stack_2" (type TFOpLambda).

Dimension 1 in both shapes must be equal, but are 8 and 1. Shapes are [?,8] and [?,1].
	From merging shape 1 with other shapes. for '{{node tf.stack_2/stack_1}} = Pack[N=4, T=DT_FLOAT, axis=0](Placeholder, Placeholder_1, Placeholder_2, Placeholder_3)' with input shapes: [?,8], [?,8], [?,1], [?,1].

Call arguments received:
  • values=['tf.Tensor(shape=(None, 8), dtype=float32)', 'tf.Tensor(shape=(None, 8), dtype=float32)', 'tf.Tensor(shape=(None, 1), dtype=float32)', 'tf.Tensor(shape=(None, 1), dtype=float32)']
  • axis=0
  • name=stack

The warning says the DF model must be trained before its output is used, because the output shape might change after training.

So I moved the ensembling step to after fitting the models, as shown below, because I'm not sure what else to do:

ensemble_nn_only.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),  # integer class labels, num_classes outputs
        metrics=["accuracy"])

ensemble_nn_only.fit(train_dataset, epochs=20, validation_data=test_dataset)

train_dataset_with_preprocessing = train_dataset.map(lambda x,y: (preprocessor(x), y))
test_dataset_with_preprocessing = test_dataset.map(lambda x,y: (preprocessor(x), y))

model_3.fit(train_dataset_with_preprocessing)
model_4.fit(train_dataset_with_preprocessing)

mean_nn_and_df = tf.reduce_mean(
    tf.stack([m1_pred, m2_pred, m3_pred, m4_pred], axis=0), axis=0)
ensemble_nn_and_df = tf.keras.models.Model(raw_features, mean_nn_and_df)

The code above fits the NNs and the RFs, but still errors out when stacking the outputs:

ValueError                                Traceback (most recent call last)
<ipython-input-26-15c6b2e5cd72> in <module>()
      1 mean_nn_and_df = tf.reduce_mean(
----> 2     tf.stack([m1_pred, m2_pred, m3_pred, m4_pred], axis=0), axis=0)
      3 ensemble_nn_and_df = tf.keras.models.Model(raw_features, mean_nn_and_df)

2 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

ValueError: Exception encountered when calling layer "tf.stack_5" (type TFOpLambda).

Dimension 1 in both shapes must be equal, but are 8 and 1. Shapes are [?,8] and [?,1].
	From merging shape 1 with other shapes. for '{{node tf.stack_5/stack_1}} = Pack[N=4, T=DT_FLOAT, axis=0](Placeholder, Placeholder_1, Placeholder_2, Placeholder_3)' with input shapes: [?,8], [?,8], [?,1], [?,1].

Call arguments received:
  • values=['tf.Tensor(shape=(None, 8), dtype=float32)', 'tf.Tensor(shape=(None, 8), dtype=float32)', 'tf.Tensor(shape=(None, 1), dtype=float32)', 'tf.Tensor(shape=(None, 1), dtype=float32)']
  • axis=0
  • name=stack

Is there something wrong with how I'm trying to adapt this example to multi-class classification?
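One possible reading of the second traceback (an interpretation, not confirmed in this thread): `m3_pred` and `m4_pred` were produced by calling the forests on `preprocess_features` before training, and a Keras functional-API tensor keeps the shape it had when it was created. Fitting the models afterwards does not change those existing tensors, so the second `tf.stack` still sees `[?,1]`. The pure-Python mock below illustrates only that reasoning; `HypotheticalForest` is invented and does not model TF-DF internals.

```python
class HypotheticalForest:
    """Hypothetical stand-in for a TF-DF classifier; NOT the tfdf API.

    Mimics the behavior described in the warning: before training it reports
    a single placeholder column, after training one column per class.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.trained = False

    def fit(self):
        self.trained = True

    def __call__(self, batch_dim=None):
        # The returned shape is fixed at call time, like a symbolic tensor.
        width = self.num_classes if self.trained else 1
        return (batch_dim, width)


model_3 = HypotheticalForest(num_classes=8)
m3_pred = model_3()       # graph node built before training
model_3.fit()
stale_shape = m3_pred     # the old node keeps its old shape
fresh_shape = model_3()   # calling the model again after training
print(stale_shape, fresh_shape)  # (None, 1) (None, 8)
```

If this reading is right, the corresponding change in the real code would be to call `model_3(preprocess_features)` and `model_4(preprocess_features)` again after `model_3.fit(...)` and `model_4.fit(...)`, and stack those fresh outputs. That has not been verified against TF-DF here.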

@kylec123
Author

I'm still interested in this solution. Anyone have an idea on what is wrong here?

@some-guy1

It seems this is still an issue, and there is no clear way to combine a multi-class RF with a multi-class neural network.

The suggested:

adv_args = tfdf.keras.AdvancedArguments(
    predict_single_probability_for_binary_classification=False
)

doesn't work, since it only converts a binary output from one column to two columns and does nothing when there are more than two classes.

There is also the multitask option, which lets the RF produce multi-class output, but this still doesn't work: when the graph combining the neural network and the RF is created, the RF still outputs [None, 1] until it is trained, which prevents it from being combined with a multi-class neural network.

@mohammad69h94

Hi, I have the same problem. Could you find a solution?

@mohammad69h94

There isn't a solution yet?
