Deserialization of a DeepTables model with a custom objective fails #88

Open
deburky opened this issue Oct 14, 2023 · 0 comments

deburky commented Oct 14, 2023

System information

  • OS Platform and Distribution: macOS Monterey 12.6.3 (Apple M2)
  • Python version: 3.11
  • DeepTables version: 0.2.5
  • Other Python packages: keras==2.14.0 ml-dtypes==0.2.0 tensorboard==2.14.1 tensorflow-estimator==2.14.0 tensorflow-macros==2.14.0

Describe the current behavior

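When a DeepTables model trained with a custom objective (here a custom Keras loss, FocalLoss) is saved and then loaded again, deserialization fails because the custom loss function is not defined at load time. A potential workaround is to save the underlying Keras model directly and pass the custom objects when loading it: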

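import tensorflow as tf
from deeptables.models.layers import MultiColumnEmbedding, FM

# model is the DeepModel returned by DeepTable.fit(); model.model is the underlying Keras model
tf.keras.models.save_model(model.model, 'model.h5')
model2 = tf.keras.models.load_model('model.h5', custom_objects={
    'MultiColumnEmbedding': MultiColumnEmbedding,
    'FM': FM,
    'FocalLoss': FocalLoss})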

However, the result is a bare Keras model rather than a DeepTable, so using it with DeepTables is not straightforward, and the documentation does not cover this case.
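
The ValueError quoted in this issue points to `keras.utils.custom_object_scope`. Below is a minimal sketch, assuming TensorFlow 2.x with `tf.keras`, of a `FocalLoss` that accepts `**kwargs` and stores `alpha`/`gamma` in `get_config()`, so Keras can rebuild it from a serialized config whenever the class is in scope. It uses `tf.keras.losses.serialize`/`deserialize` as a stand-in for the loss-deserialization step inside model loading; the helper `roundtrip_focal_loss` is hypothetical.

```python
import tensorflow as tf


class FocalLoss(tf.keras.losses.Loss):
    """Binary focal loss that can be rebuilt from its config."""

    def __init__(self, alpha=0.25, gamma=2.0, **kwargs):
        # Accept name/reduction so from_config(get_config()) works
        super().__init__(**kwargs)
        self.alpha = alpha
        self.gamma = gamma

    def call(self, y_true, y_pred):
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1 - eps)
        y_true = tf.cast(y_true, y_pred.dtype)
        ce = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
        pt = tf.math.exp(-ce)  # probability assigned to the true class
        return tf.reduce_mean(self.alpha * (1 - pt) ** self.gamma * ce, axis=-1)

    def get_config(self):
        # Persist the hyperparameters alongside the base config (name, reduction)
        config = super().get_config()
        config.update({"alpha": self.alpha, "gamma": self.gamma})
        return config


def roundtrip_focal_loss(loss):
    """Hypothetical helper: serialize a loss, then rebuild it inside a custom object scope."""
    config = tf.keras.losses.serialize(loss)
    # Without this scope, Keras cannot resolve the class name 'FocalLoss'
    with tf.keras.utils.custom_object_scope({"FocalLoss": FocalLoss}):
        return tf.keras.losses.deserialize(config)


restored = roundtrip_focal_loss(FocalLoss(alpha=0.1, gamma=0.0))
print(type(restored).__name__, restored.alpha, restored.gamma)
```

The same `with tf.keras.utils.custom_object_scope({...}):` block could, in principle, wrap `deeptable.DeepTable.load(tmpdir)` from the reproduction script; whether `DeepTable.load` honors the scope has not been verified here.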

Describe the expected behavior

It should be possible to deserialize any DeepTables model containing a custom objective without extra effort. This is especially important for deployment; for example, there is currently no clear way to use the mlem package.
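
One way this "no extra effort" behavior could look is Keras's global registration: a class decorated with `tf.keras.utils.register_keras_serializable` is resolved by Keras deserialization without passing `custom_objects`. A minimal sketch, assuming TensorFlow 2.x; the package name `"my_package"` is a hypothetical choice. The decorated class must still be defined (imported) in the process that loads the model, and whether `DeepTable.load` goes through this lookup path is an assumption not confirmed here.

```python
import tensorflow as tf


# Registers the class as "my_package>FocalLoss" (hypothetical package name)
# in Keras's global custom-object registry.
@tf.keras.utils.register_keras_serializable(package="my_package")
class FocalLoss(tf.keras.losses.Loss):
    def __init__(self, alpha=0.25, gamma=2.0, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.gamma = gamma

    def call(self, y_true, y_pred):
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1 - eps)
        y_true = tf.cast(y_true, y_pred.dtype)
        ce = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
        pt = tf.math.exp(-ce)
        return tf.reduce_mean(self.alpha * (1 - pt) ** self.gamma * ce, axis=-1)

    def get_config(self):
        # Hyperparameters must be in the config for the class to be rebuilt
        config = super().get_config()
        config.update({"alpha": self.alpha, "gamma": self.gamma})
        return config


# No custom_objects and no scope: the registry resolves the class by its registered name.
config = tf.keras.losses.serialize(FocalLoss(alpha=0.1, gamma=0.0))
restored = tf.keras.losses.deserialize(config)
print(tf.keras.utils.get_registered_name(FocalLoss), restored.alpha)
```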

Standalone code to reproduce the issue

Link to Colab notebook

First, create a virtual environment and install the required packages:

python3.11 -m venv .venv
source .venv/bin/activate
pip install pandas numpy scikit-learn tensorflow deeptables dask jupyter

Then run the following code to reproduce the issue:

"""
Original file is located at
    https://colab.research.google.com/drive/1XECkcRpqYCPlgRLCuqPn0BDFwKx8ujzc
"""

import tensorflow as tf
from tensorflow.keras.losses import Loss
from deeptables.models import deeptable, deepnets
from deeptables.datasets import dsutils
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

class FocalLoss(Loss):
    """
    We want to maximize the likelihood of correctly classifying challenging
    examples while giving less emphasis to well-classified examples.
    """
    def __init__(self, alpha=0.25, gamma=2.0, **kwargs):
        # Forward name/reduction so Keras can rebuild the loss from its config
        super().__init__(**kwargs)
        self.alpha = alpha
        self.gamma = gamma

    def call(self, y_true, y_pred):
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1 - eps)

        # Cast y_true to float32
        y_true = tf.cast(y_true, tf.float32)

        ce_loss = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
        # p_t is the predicted probability of the true class: exp(-CE)
        pt = tf.math.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return tf.reduce_mean(focal_loss, axis=-1)

    def get_config(self):
        # Include alpha/gamma so the loss can be re-created on load
        config = super().get_config()
        config.update({"alpha": self.alpha, "gamma": self.gamma})
        return config

# Generate a synthetic dataset
X, y = make_classification(n_samples=10000, n_features=5, n_classes=2, n_informative=3, random_state=42)
X_trn, X_tst, y_trn, y_tst = train_test_split(X, y, stratify=y, test_size=0.33, random_state=62)

config = deeptable.ModelConfig(
    nets=deepnets.DeepFM,
    loss=FocalLoss(alpha=0.1, gamma=0.0),
    metrics=["AUC"],
    auto_discrete=True
)
dt_fl = deeptable.DeepTable(config=config)
model_fl, history_fl = dt_fl.fit(X_trn, y_trn, epochs=10)

result = dt_fl.evaluate(X_tst, y_tst, batch_size=512, verbose=0)
print(result)

preds_fl = dt_fl.predict_proba(X_tst)

import tempfile
tmpdir = tempfile.mkdtemp()
dt_fl.save(tmpdir)
model_load = deeptable.DeepTable.load(tmpdir)
model_load.evaluate(X_tst, y_tst)

The output produced is:

ValueError: Unknown loss function: 'FocalLoss'. Please ensure you are using a `keras.utils.custom_object_scope` and that this object is included in the scope. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
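
Independently of the loading error, the FocalLoss values can be sanity-checked outside TensorFlow. In the binary focal-loss formulation, p_t = exp(-CE) is the predicted probability of the true class, so well-classified examples (p_t near 1) are down-weighted by (1 - p_t)^gamma. A minimal NumPy sketch using the same constant `alpha` weight as the script (the function name `focal_loss_np` is hypothetical; the original focal-loss paper uses a class-dependent alpha_t instead):

```python
import numpy as np


def focal_loss_np(y_true, y_pred, alpha=0.25, gamma=2.0, eps=1e-7):
    """Hypothetical reference: mean binary focal loss with a constant alpha weight."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1 - eps)
    # Binary cross-entropy per element
    ce = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    # p_t: probability the model assigns to the true class, always in (0, 1]
    pt = np.exp(-ce)
    return float(np.mean(alpha * (1 - pt) ** gamma * ce))


easy = focal_loss_np([1.0], [0.9])  # confident and correct
hard = focal_loss_np([1.0], [0.1])  # confident and wrong
print(easy, hard)
```

If p_t were computed as exp(+CE), it would be at least 1, making (1 - p_t) non-positive and the modulating factor meaningless.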

Any support or feedback on this would be appreciated.
