Retrieving standard deviation of predictions from ensemble #42

Open
jamiecollinson opened this issue Jul 27, 2021 · 6 comments
Labels
enhancement New feature or request

Comments

@jamiecollinson

I'm working on an application where I'd like to retrieve the standard deviation of the predictions made by the trees within an ensemble (currently a tfdf.keras.RandomForestModel) to use as an estimate of the confidence of a given prediction.

It looks like I could do this by running a prediction on each individual tree with inspector.iterate_on_nodes() but is there a better way to do this via the main predict method, and if not would you consider this as an enhancement?

@achoum
Collaborator

achoum commented Jul 29, 2021

Hi Jamie,

As you correctly noted, the API does not allow obtaining the individual tree predictions directly. Please feel free to create a feature request :). If we see traction, we will prioritize it.

In the meantime, there are two alternative solutions:

  1. Training multiple Random Forest models, each with one tree (while making sure to change the random seed).
  2. Training a single Random Forest model and splitting it into individual trees using the model inspector and model builder.
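
Solution 1 could be sketched roughly as below. This is only an illustration: `train_single_tree_models` and `make_tfdf_single_tree` are hypothetical helper names, and the sketch assumes that the `random_seed` hyper-parameter is what varies the randomness between the single-tree models.

```python
from typing import Any, Callable, Iterable, List


def train_single_tree_models(
    train_ds: Any,
    seeds: Iterable[int],
    make_model: Callable[[int], Any],
) -> List[Any]:
  """Trains one model per seed and returns the trained models."""
  models = []
  for seed in seeds:
    model = make_model(seed)  # A fresh model, seeded differently each time.
    model.fit(train_ds)
    models.append(model)
  return models


def make_tfdf_single_tree(seed: int) -> Any:
  """Builds a one-tree Random Forest (hypothetical helper)."""
  # Imported lazily so the helper above stays usable without TF-DF installed.
  import tensorflow_decision_forests as tfdf
  return tfdf.keras.RandomForestModel(num_trees=1, random_seed=seed)


# Intended usage, with `train_ds` prepared as for any TF-DF model:
# models = train_single_tree_models(train_ds, range(10), make_tfdf_single_tree)
```

Each call trains from scratch, which is why this route is expected to be slower than splitting one trained forest.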

Using the model builder to generate the individual trees might be easier than running the inference manually in Python.

While faster than solution 1, solution 2 can still be slow on large models and datasets, because deserializing and re-serializing the model in Python is relatively slow. It would look like this:

import os

import tensorflow as tf
import tensorflow_decision_forests as tfdf

# Train a Random Forest with 10 trees.
model = tfdf.keras.RandomForestModel(num_trees=10)
model.fit(train_ds)

# Extract each of the 10 trees into a separate model.
inspector = model.make_inspector()

# TODO: Run in parallel.
models = []
for tree_idx, tree in enumerate(inspector.extract_all_trees()):
  print(f"Extract and export tree #{tree_idx}")

  # Create a RF model with a single tree, saved in its own directory.
  path = os.path.join("/tmp/model", str(tree_idx))
  builder = tfdf.builder.RandomForestBuilder(
      path=path,
      objective=inspector.objective())
  builder.add_tree(tree)
  builder.close()

  models.append(tf.keras.models.load_model(path))

# Compute the predictions of all the trees together.
class CombinedModel(tf.keras.Model):
  def call(self, inputs):
    # We assume a binary classification model that returns a single
    # probability. In case of multi-class classification, use tf.stack instead.
    return tf.concat([submodel(inputs) for submodel in models], axis=1)

print("Prediction of all the trees")
combined_model = CombinedModel()
all_trees_predictions = combined_model.predict(test_with_cast_ds)

See this colab for a full example.
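
Once the per-tree predictions are available, the standard deviation asked about in the original question can be computed across the tree axis. The following is a minimal sketch using only the Python standard library; `per_example_mean_and_std` is a hypothetical helper, the input values are made up, and treating the spread across trees as a confidence estimate is itself an assumption rather than a calibrated measure.

```python
import statistics
from typing import List, Sequence, Tuple


def per_example_mean_and_std(
    all_trees_predictions: Sequence[Sequence[float]],
) -> List[Tuple[float, float]]:
  """Returns (mean, std) across trees for each example.

  Expects one row per example and one column per tree, i.e. a
  [num_examples, num_trees] layout.
  """
  results = []
  for row in all_trees_predictions:
    values = [float(v) for v in row]
    mean = statistics.fmean(values)
    # Population standard deviation: the trees are treated as the whole
    # ensemble rather than a sample from a larger one (an assumption).
    std = statistics.pstdev(values)
    results.append((mean, std))
  return results


# Made-up predictions for 2 examples and 4 trees (illustrative values only).
example_predictions = [
    [0.9, 0.8, 1.0, 0.9],
    [0.2, 0.7, 0.4, 0.9],
]
for mean, std in per_example_mean_and_std(example_predictions):
  print(f"mean={mean:.3f} std={std:.3f}")
```

If the predictions are held in a NumPy array, `np.std(all_trees_predictions, axis=1)` should give the same per-example spread in a single call.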

Cheers,
M.

@jamiecollinson
Author

Hi @achoum, as mentioned on the forum, thanks so much for this! I will try this out and raise a feature request :-)

@jfold

jfold commented Oct 6, 2021

Any updates on this? :-) I am also very interested in this feature.

@mmiller8878

I would also be interested in this as a feature!

Ideally the model should have the option to output the individual tree predictions so users could define their own confidence bounds for predictions.
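
To illustrate what user-defined bounds might look like once per-tree predictions are exposed, here is a small sketch. `empirical_bounds` is a hypothetical helper, and the nearest-rank quantile rule is just one possible choice, not something the library provides.

```python
import math
from typing import Sequence, Tuple


def empirical_bounds(
    tree_predictions: Sequence[float],
    coverage: float = 0.9,
) -> Tuple[float, float]:
  """Lower/upper nearest-rank quantiles of one example's per-tree predictions."""
  if not 0.0 < coverage < 1.0:
    raise ValueError("coverage must be in (0, 1)")
  values = sorted(float(v) for v in tree_predictions)
  if not values:
    raise ValueError("need at least one tree prediction")

  n = len(values)
  tail = (1.0 - coverage) / 2.0  # Probability mass left out on each side.

  def quantile(q: float) -> float:
    # Nearest-rank rule: smallest value whose rank covers a fraction q.
    rank = max(1, math.ceil(q * n))
    return values[rank - 1]

  return quantile(tail), quantile(1.0 - tail)


# Made-up per-tree predictions for a single example.
low, high = empirical_bounds([0.1, 0.2, 0.3, 0.4, 0.5], coverage=0.5)
print(low, high)
```

With only a handful of trees such bounds are coarse, so a larger ensemble would be needed for them to be meaningful.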

@jamiecollinson could you re-open the issue and @achoum could you tag as an enhancement?

@rstz rstz reopened this Oct 19, 2022
@rstz rstz added the enhancement New feature or request label Oct 19, 2022
@rstz
Collaborator

rstz commented Oct 19, 2022

Done :)

@Gail529

Gail529 commented Oct 25, 2022

Any updates on this? It would be really helpful.


6 participants