Keras model pickle-able but tf.keras model not pickle-able #34697

Closed
Edwin-Koh1 opened this issue Nov 29, 2019 · 60 comments
Labels
comp:keras Keras related issues stale This label marks the issue/pr stale - to be closed automatically if no activity stat:awaiting response Status - Awaiting response from author stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.5 Issues related to TF 2.5 type:bug Bug

Comments

@Edwin-Koh1

Edwin-Koh1 commented Nov 29, 2019

System information

  • Windows 10
  • Tensorflow 2.0 (CPU)
  • joblib 0.14.0
  • Python 3.7.5
  • Keras 2.3.1

Hello everybody! This is my first post, so please forgive me if I have missed something. I'm trying to use a genetic algorithm to train and evaluate multiple NN architectures, so I need to parallelize them on a multi-core CPU, and I used joblib for this. However, I got stuck on my tf.keras code because it wasn't pickleable. After many hours of debugging I finally realised that tf.keras models are not pickleable whereas keras models are.

Describe the current behavior
The code below works, but if you replace keras with tf.keras, you get this error:
Could not pickle the task to send it to the workers.

Describe the expected behavior
Moving forward, tf.keras should be replacing keras and therefore tf.keras should also be pickleable.

Code to reproduce the issue

# The following simple code illustrates the problem:
from joblib import Parallel, delayed
import keras
import tensorflow as tf

def test(i):  # accept the argument passed by delayed(test)(i)
    model = keras.models.Sequential()
    return

Parallel(n_jobs=8)(delayed(test)(i) for i in range(10))  # this works as intended

def test_tf(i):  # same function, but built with tf.keras
    model = tf.keras.models.Sequential()
    return

Parallel(n_jobs=8)(delayed(test_tf)(i) for i in range(10))  # this raises the error above

Other comments
I guess a quick fix would be to replace tf.keras with plain keras throughout the existing code, but seeing as standalone keras support will be discontinued and absorbed by Tensorflow 2.0, I think this should be fixed.

@Edwin-Koh1 Edwin-Koh1 changed the title Keras model pickleable but tf.keras model not pickleable Keras model pickle-able but tf.keras model not pickle-able Nov 29, 2019
@ravikyram ravikyram self-assigned this Dec 2, 2019
@ravikyram ravikyram added comp:keras Keras related issues TF 2.0 Issues relating to TensorFlow 2.0 labels Dec 2, 2019
@ravikyram
Contributor

@Edwin-Koh1

Can you please check with the nightly version (!pip install tf-nightly==2.1.0dev20191201) and see if the error still persists? There are a lot of performance improvements in the latest nightly versions. Thanks!

@ravikyram ravikyram added the stat:awaiting response Status - Awaiting response from author label Dec 2, 2019
@ravikyram ravikyram added the type:support Support issues label Dec 18, 2019
@ravikyram
Contributor

Automatically closing due to lack of recent activity. Please update the issue when new information becomes available, and we will reopen the issue. Thanks!

@hartikainen
Contributor

hartikainen commented Jan 13, 2020

@ravikyram I'm still seeing this issue on tensorflow==2.1.0:

import pickle

import tensorflow as tf


def main():
    model_1 = tf.keras.Sequential((
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(1, activation='linear'),
    ))

    _ = model_1(tf.random.uniform((15, 3)))

    model_2 = pickle.loads(pickle.dumps(model_1))

    for w1, w2 in zip(model_1.get_weights(), model_2.get_weights()):
        tf.debugging.assert_equal(w1, w2)


if __name__ == '__main__':
    main()

results in

Traceback (most recent call last):
  File "/Users/hartikainen/conda/envs/softlearning-3/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/hartikainen/conda/envs/softlearning-3/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/hartikainen/github/rail-berkeley/softlearning-3/tests/test_pickle_keras_model.py", line 25, in <module>
    main()
  File "/Users/hartikainen/github/rail-berkeley/softlearning-3/tests/test_pickle_keras_model.py", line 18, in main
    model_2 = pickle.loads(pickle.dumps(model_1))
TypeError: can't pickle weakref objects
$ pip freeze | grep "tf\|tensor"
tensorboard==2.1.0
tensorflow==2.1.0
tensorflow-estimator==2.1.0
tensorflow-probability==0.9.0
$ python --version
Python 3.7.5

@ravikyram ravikyram reopened this Jan 13, 2020
@ravikyram
Contributor

I have tried on colab with TF versions 2.1.0-rc2 and 2.2.0-dev20200113 and was able to reproduce the issue. Please find the gist here. Thanks!

@ravikyram ravikyram assigned ymodak and unassigned ravikyram Jan 13, 2020
@ravikyram ravikyram added type:bug Bug and removed type:support Support issues labels Jan 13, 2020
@ymodak ymodak added stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed stat:awaiting response Status - Awaiting response from author labels Jan 14, 2020
@hartikainen
Contributor

@ravikyram, should keras functional models be picklable too or not? I'd assume that if Sequential models are, then functional models should be too? Or do functional models have some properties that make them harder to pickle?

$ python -m tests.test_pickle_keras_functional_model
2020-01-17 16:47:08.567598: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-01-17 16:47:08.581327: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fa0a55aa6c0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-01-17 16:47:08.581362: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
Traceback (most recent call last):
  File "/Users/hartikainen/conda/envs/softlearning-3/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/hartikainen/conda/envs/softlearning-3/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/hartikainen/github/rail-berkeley/softlearning-3/tests/test_pickle_keras_functional_model.py", line 20, in <module>
    main()
  File "/Users/hartikainen/github/rail-berkeley/softlearning-3/tests/test_pickle_keras_functional_model.py", line 13, in main
    model_2 = pickle.loads(pickle.dumps(model_1))
TypeError: can't pickle _thread.RLock objects
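The two tracebacks in this thread name different objects (weakref and _thread.RLock), but both are standard Python objects that pickle refuses to serialize, so any container that holds one fails too. The sketch below reproduces both errors with the standard library only; it does not use TensorFlow, and `Target` and `pickle_error` are hypothetical names made up for the illustration. The exact wording of the error message differs between Python versions.

```python
import pickle
import threading
import weakref


class Target:
    """Hypothetical class; only needed because weakrefs need a referent."""


def pickle_error(obj):
    """Return the exception raised by pickle.dumps(obj), or None on success."""
    try:
        pickle.dumps(obj)
    except Exception as exc:  # broad on purpose: we only report what happened
        return exc
    return None


target = Target()

# A plain dict stands in for an object's attribute dictionary.
lock_error = pickle_error({"lock": threading.RLock()})
weakref_error = pickle_error({"ref": weakref.ref(target)})

print(type(lock_error).__name__, lock_error)
print(type(weakref_error).__name__, weakref_error)
```

If a model keeps such objects somewhere in its attributes, that would be enough to trigger the errors shown above, which is presumably why the workarounds later in the thread avoid pickling the attribute dictionary directly.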

@marfox

marfox commented Mar 30, 2020

Hi everyone,
I'm trying to switch from standalone keras to tensorflow.keras as per the recommendation at https://keras.io/.
I'm hitting the same exception as #34697 (comment) with joblib (which uses pickle under the hood).

System information:

  • Debian 10 (buster)
  • Python 3.7.6
  • joblib 0.14.1
  • tensorflow 2.1.0

Script to reproduce:

import joblib
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential()
model.add(Dense(1, input_dim=42, activation='sigmoid'))
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])
joblib.dump(model, 'model.pkl')

Output:

TypeError: can't pickle _thread.RLock objects

marfox pushed a commit to Wikidata/soweego that referenced this issue Apr 1, 2020
…) have new default learning rates"

This reverts commit f6edcd5.
The switch is blocked: we are hitting tensorflow/tensorflow#34697.
Left a comment at tensorflow/tensorflow#34697 (comment)
@epetrovski

epetrovski commented May 4, 2020

Here's a fix adapted from http://zachmoshe.com/2017/04/03/pickling-keras-models.html intended for solving the same issue back when Keras models used to not be pickleable.

import pickle
import tempfile
from tensorflow.keras.models import Sequential, load_model, save_model, Model
from tensorflow.keras.layers import Dense

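# Hotfix function
def make_keras_picklable():
    def __getstate__(self):
        # Serialize the whole model to a temporary HDF5 file and keep the raw bytes.
        # Note: reopening a NamedTemporaryFile by name while it is still open
        # works on Unix but not on Windows.
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            save_model(self, fd.name, overwrite=True)
            model_str = fd.read()
        return {'model_str': model_str}

    def __setstate__(self, state):
        # Write the bytes back to a temporary file and let Keras rebuild the model.
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])
            fd.flush()
            model = load_model(fd.name)
        # Adopt the attributes of the freshly loaded model.
        self.__dict__ = model.__dict__

    # Patch the base class so that its subclasses (e.g. Sequential) inherit the hooks.
    cls = Model
    cls.__getstate__ = __getstate__
    cls.__setstate__ = __setstate__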

# Run the function
make_keras_picklable()

# Create the model
model = Sequential()
model.add(Dense(1, input_dim=42, activation='sigmoid'))
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])

# Save
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)

@leandrocouto

@epetrovski Should I call this code whenever I'm about to pickle a model or can I just call it at the beginning of my application (before creating the model)?

@epetrovski

@epetrovski Should I call this code whenever I'm about to pickle a model or can I just call it at the beginning of my application (before creating the model)?

You can definitely just call it once at the beginning of your app, after importing tensorflow.keras.models.Model. Executing the function adds two new methods, __getstate__() and __setstate__(), to the tensorflow.keras.models.Model class, so it should work every time you want to pickle an instance of the updated tf.keras Model class, i.e. your own model.
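To illustrate why a single call is enough, here is a minimal sketch in plain Python, without TensorFlow. `Resource` and `make_resource_picklable` are hypothetical stand-ins for the model class and the hotfix. Because the hooks are attached to the class, pickle finds them on every instance, whether it was created before or after the patch.

```python
import pickle
import threading


class Resource:
    """Hypothetical stand-in for a model: data plus an unpicklable lock."""

    def __init__(self, weights):
        self.weights = weights
        self._lock = threading.RLock()


def make_resource_picklable():
    """Attach pickling hooks to the class; call this once."""

    def __getstate__(self):
        # Keep only the picklable part of the instance state.
        return {"weights": self.weights}

    def __setstate__(self, state):
        self.weights = state["weights"]
        self._lock = threading.RLock()  # rebuild what was left out

    Resource.__getstate__ = __getstate__
    Resource.__setstate__ = __setstate__


before = Resource([1.0, 2.0])   # created before the patch
make_resource_picklable()       # called once
after = Resource([3.0])         # created after the patch

restored_before = pickle.loads(pickle.dumps(before))
restored_after = pickle.loads(pickle.dumps(after))
print(restored_before.weights, restored_after.weights)  # [1.0, 2.0] [3.0]
```

The same patch has to be applied in the process that unpickles, since `__setstate__` is looked up on the class at load time.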

@adriangb
Contributor

adriangb commented May 12, 2020

Here is an alternative to @epetrovski 's answer that does not require saving to a file:

import pickle

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense
from tensorflow.python.keras.layers import deserialize, serialize
from tensorflow.python.keras.saving import saving_utils


def unpack(model, training_config, weights):
    restored_model = deserialize(model)
    if training_config is not None:
        restored_model.compile(
            **saving_utils.compile_args_from_training_config(
                training_config
            )
        )
    restored_model.set_weights(weights)
    return restored_model

# Hotfix function
def make_keras_picklable():

    def __reduce__(self):
        model_metadata = saving_utils.model_metadata(self)
        training_config = model_metadata.get("training_config", None)
        model = serialize(self)
        weights = self.get_weights()
        return (unpack, (model, training_config, weights))

    cls = Model
    cls.__reduce__ = __reduce__

# Run the function
make_keras_picklable()

# Create the model
model = Sequential()
model.add(Dense(1, input_dim=42, activation='sigmoid'))
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])

# Save
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)

Source: https://docs.python.org/3/library/pickle.html#object.__reduce__

I feel like maybe this could be added to Model? Are there any cases where this would not work?
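For readers unfamiliar with `__reduce__`, the following is a minimal sketch of the same pattern on a plain Python class, assuming nothing about Keras internals. `TinyModel` and `rebuild` are hypothetical names that play the roles of the model and of `unpack` above. `__reduce__` returns a callable plus its arguments, and pickle stores only those, so the unpicklable attribute never reaches the pickler.

```python
import pickle
import threading


class TinyModel:
    """Hypothetical stand-in for a Keras model."""

    def __init__(self, units):
        self.units = units
        self.weights = [0.0] * units
        self._lock = threading.RLock()  # would normally break pickling

    def get_config(self):
        return {"units": self.units}

    def __reduce__(self):
        # Tell pickle: call rebuild(config, weights) to recreate this object.
        return (rebuild, (self.get_config(), list(self.weights)))


def rebuild(config, weights):
    """Module-level function, so pickle can store it by reference."""
    model = TinyModel(**config)
    model.weights = weights
    return model


original = TinyModel(3)
original.weights = [0.1, 0.2, 0.3]

clone = pickle.loads(pickle.dumps(original))
print(clone.units, clone.weights)  # 3 [0.1, 0.2, 0.3]
```

One design consequence is that anything not covered by the config or the weights is lost in the round trip, so this approach only restores what the rebuild function knows how to recreate.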

@sushreebarsa
Contributor

@Edwin-Koh1 Could you please refer to link1 and link2, and let us know if they help? Thanks!

@sushreebarsa sushreebarsa added stat:awaiting response Status - Awaiting response from author TF 2.5 Issues related to TF 2.5 and removed TF 2.2 Issues related to TF 2.2 labels Sep 11, 2021
@Edwin-Koh1
Author

Edwin-Koh1 commented Sep 13, 2021

Hey all, I believe the issue was resolved by @adriangb in #39609 and the keras pull #14748. Should be fixed in Keras and Tensorflow 2.6.0 but have not tested this yet.

@ageron
Contributor

ageron commented Sep 14, 2021

@Edwin-Koh1 , indeed, the problem seems fixed in TF 2.6.0! 👍

This now works (while it failed in TF 2.5.1):

import joblib
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential([Dense(1, input_shape=[42], activation='sigmoid')])
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])
joblib.dump(model, 'model.pkl')

@adriangb
Contributor

Glad we could fix it for you @ageron! It should also work with Functional models by the way.

@Edwin-Koh1 should we close the issue?

@ageron
Contributor

ageron commented Sep 14, 2021

It works with joblib in TF 2.6.0, but not with pickle, however:

import pickle
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential([Dense(1, input_shape=[42], activation='sigmoid')])
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])
with open('model.pkl', 'wb') as f:
  pickle.dump(model, f)

Raises:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-e391f4844d65> in <module>()
      6 model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])
      7 with open('model.pkl', 'wb') as f:
----> 8   pickle.dump(model, f)

TypeError: can't pickle weakref objects

@adriangb
Contributor

Can you try with tf-nightly? I'm wondering if the fix was actually not in 2.6.0 and your example just happens to be broken in 2.5.1. In this colab notebook, on tf-nightly, your example runs fine.

@sushreebarsa
Contributor

sushreebarsa commented Sep 19, 2021

@Edwin-Koh1 Could you please let us know if this issue is resolved for you? If it is resolved, please feel free to move this ticket to closed status. Thank you!

@sushreebarsa sushreebarsa added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting response Status - Awaiting response from author labels Sep 19, 2021
@ageron
Contributor

ageron commented Sep 20, 2021

@adriangb you're right, both the joblib and pickle code examples work with tf-nightly (while only the joblib example worked in 2.6.0). 👍
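Since the behaviour reported in this thread differs between TF 2.5.1, 2.6.0 and tf-nightly, code that must run on several versions could probe at runtime instead of checking version numbers. The sketch below is one possible shape for that, using only the standard library. `pickle_or_fallback` and `describe` are hypothetical names, and `describe` is just a placeholder where a real fallback (for example, one of the workarounds in this thread) would go.

```python
import pickle
import threading


def pickle_or_fallback(obj, fallback):
    """Try plain pickle first; if that fails, hand the object to `fallback`."""
    try:
        return "pickle", pickle.dumps(obj)
    except Exception:  # pickle may raise TypeError or PicklingError
        return "fallback", fallback(obj)


def describe(obj):
    # Placeholder fallback for the illustration: a real one would serialize
    # the object some other way.
    return repr(type(obj)).encode()


how, payload = pickle_or_fallback({"units": 16}, describe)
print(how)  # pickle

how, payload = pickle_or_fallback(threading.Lock(), describe)
print(how)  # fallback
```

The returned tag lets the loading side know which path produced the payload.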

@google-ml-butler

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Sep 27, 2021
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.

@raj-gupta1

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

Issue resolved

@AhmedCoolProjects

Here's a fix adapted from http://zachmoshe.com/2017/04/03/pickling-keras-models.html intended for solving the same issue back when Keras models used to not be pickleable.

That really worked for me. However, I ran into an issue when loading the model with pickle: I have to keep using tensorflow so that the pickle load function can load the model.

@raghunathanp95

Replying to @ageron's report above that, in TF 2.6.0, pickle.dump raises "TypeError: can't pickle weakref objects" while joblib works:

Try this instead of pickle.dump or joblib.dump:

import tensorflow as tf
tf.keras.models.save_model(model, filename)
