RuntimeError when running baselines/imagenet/sngp.py #329

Open
batzner opened this issue Apr 15, 2021 · 5 comments

@batzner

batzner commented Apr 15, 2021

Dear uncertainty-baselines authors,

I am trying to run the SNGP training on ImageNet using uncertainty-baselines/baselines/imagenet/sngp.py.

It errors during the execution of the first training step with the following message:

 RuntimeError: `merge_call` called while defining a new graph or a tf.function.
This can often happen if the function `fn` passed to `strategy.run()` 
contains a nested `@tf.function`, and the nested `@tf.function` contains 
a synchronization point, such as aggregating gradients (e.g, optimizer.apply_gradients),
or if the function `fn` uses a control flow statement which contains a synchronization 
point in the body. Such behaviors are not yet supported. Instead, please avoid 
nested `tf.function`s or control flow statements that may potentially cross a
synchronization boundary, for example, wrap the `fn` passed to `strategy.run` 
or the entire `strategy.run` inside a `tf.function` or move the control flow out of `fn`

This is the stack trace:

RuntimeError: in user code:

    .../lib/uncertainty-baselines/baselines/imagenet/sngp_tmp.py:290 step_fn  *
        model.layers[-1].reset_covariance_matrix()
    ../edward2/edward2/tensorflow/layers/random_feature.py:219 reset_covariance_matrix  *
        self._gp_cov_layer.reset_precision_matrix()
    ../edward2/edward2/tensorflow/layers/random_feature.py:363 reset_precision_matrix  *
        precision_matrix_reset_op = self.precision_matrix.assign(
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:685 assign  **
        return values_util.on_write_assign(self, value, use_locking=use_locking,
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/values_util.py:33 on_write_assign
        return var._update(  # pylint: disable=protected-access
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:827 _update
        return self._update_replica(update_fn, value, **kwargs)
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:897 _update_replica
        return _on_write_update_replica(self, update_fn, value, **kwargs)
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/values.py:71 _on_write_update_replica
        return ds_context.get_replica_context().merge_call(
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2715 merge_call
        return self._merge_call(merge_fn, args, kwargs)
    .../venv/lib/python3.8/site-packages/tensorflow/python/distribute/mirrored_run.py:432 _merge_call
        raise RuntimeError(

Judging from the stack trace, the self.precision_matrix.assign call in edward2/edward2/tensorflow/layers/random_feature.py seems to trigger this error: it runs inside step_fn, which is passed to strategy.run within a tf.function, and the assign ends up calling merge_call.

What can I do to fix this?

@dustinvtran
Member

@jereliu

Hi @batzner! Were you able to resolve this?

@batzner
Author

batzner commented Apr 28, 2021

Unfortunately, I was not able to resolve this. In GitHub issues in other repos, people who got the same error message were able to resolve it by passing synchronization=tf.VariableSynchronization.ON_READ to the self.add_weight(...) call (the current call is shown below).

But I am not sure whether the on-read-synchronization is the correct behavior in this case.

self.precision_matrix = (
    self.add_weight(
        name='gp_precision_matrix',
        shape=(gp_feature_dim, gp_feature_dim),
        dtype=self.dtype,
        initializer=tf.keras.initializers.Identity(self.ridge_penalty),
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA))

(from https://github.com/google/edward2/blob/807bd74d93c607a5a4030c4ef7debecf89f8b6ab/edward2/tensorflow/layers/random_feature.py#L337)
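For illustration, the ON_READ workaround mentioned above could be sketched roughly as follows. This is a minimal standalone sketch, not the edward2 code: it assumes a plain tf.Variable in place of self.add_weight(...), a hypothetical 4x4 matrix, and made-up names (train_step, step_fn). It only shows that the conditional assign can run inside strategy.run; whether on-read synchronization gives the right behavior for SNGP remains an open question.

```python
import tensorflow as tf

# Uses all visible GPUs, or a single CPU device if none are available.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Stand-in for the layer's precision matrix (hypothetical 4x4 size).
    # ON_READ keeps an independent copy per replica, so assign() inside a
    # replica does not need to synchronize with the other replicas.
    precision_matrix = tf.Variable(
        tf.eye(4),
        trainable=False,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        name="gp_precision_matrix",
    )


@tf.function
def train_step(step):
    def step_fn(step):
        # Assumed shape of the failing pattern: a conditional reset that is
        # executed inside strategy.run.
        if tf.equal(step, 0):
            precision_matrix.assign(2.0 * tf.eye(4))

    strategy.run(step_fn, args=(step,))


train_step(tf.constant(0))
# Reading outside strategy.run returns the first replica's copy.
reset_value = precision_matrix.read_value().numpy()
```

With ONLY_FIRST_REPLICA aggregation, a read outside strategy.run takes the value from the first replica only, so the replicas are never combined; that is part of why it is unclear whether this matches the intended covariance update.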

@znado
Collaborator

znado commented May 11, 2021

Hey! Sorry for the delay.

What was the command you were using to run the code, and what was the GPU/version of CUDA/version of TF? Also, have you been able to reproduce this at the current HEAD?

@gpleiss
Contributor

gpleiss commented Sep 7, 2021

I get the exact same error. Fresh install (nightly version of tf - 9/7/21), CUDA 11.2.

@psiden

psiden commented Jan 11, 2022

I also get this error with TF 2.7.
