How to solve 'Input is not Invertible error'? #40

arunpatro opened this issue Aug 3, 2018 · 10 comments

Comments

@arunpatro

I am trying to train a Glow model on a custom dataset. However, during training I frequently get a tensorflow.python.framework.errors_impl.InvalidArgumentError: Input is not invertible error. Looking at the logs, I see that the training/validation stats have reached either inf or nan.

I then tried simply to reproduce your qualitative results for CelebA 256x256, but I hit the same problem. I am not sure how to debug it. I downloaded the celeba-tfr dataset locally.

Command:

python train.py --problem celeba --image_size 256 --n_level 6 --depth 32 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5 --data_dir=./celeba-tfr --verbose --epochs_full_valid=1 --epochs_full_sample=1 --n_train=30 --n_test=30

Namespace:

Namespace(anchor_size=32, beta1=0.9, category='', dal=1, data_dir='./celeba-tfr', depth=32, direct_iterator=True, epochs=1000000, epochs_full_sample=1, epochs_full_valid=1, epochs_warmup=10, flow_coupling=0, flow_permutation=2, fmap=1, full_test_its=30, gradient_checkpointing=1, image_size=256, inference=False, learntop=True, local_batch_init=4, local_batch_test=1, local_batch_train=1, logdir='./logs', lr=0.001, n_batch_init=256, n_batch_test=50, n_batch_train=64, n_bins=32.0, n_bits_x=5, n_levels=6, n_sample=1, n_test=30, n_train=30, n_y=1, optimizer='adamax', pmap=16, polyak_epochs=1, problem='celeba', restore_path='', rnd_crop=False, seed=0, test_its=1, top_shape=[4, 4, 384], train_its=1, verbose=True, weight_decay=1.0, weight_y=0.0, width=512, ycond=False)

Trace:

Starting training. Logging to /home/ubuntu/glow_/logs/
epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg
0 179.9140625 [2.5411766 2.5411766 0.        1.       ]
1 64 1 0.0 179.9 88.8 177.1 445.7 [2.5411766 2.5411766 0.        1.       ] [2.7737396 2.7737396 0.
      1.       ]  *
64 5.25806736946106 [2.6743338 2.6743338 0.        1.       ]
2 128 2 0.2 5.3 36.1 161.6 203.0 [2.6743338 2.6743338 0.        1.       ] [nan nan  0.  1.]
128 4.962073087692261 [nan nan  0.  1.]
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1322, in _do_call
    return fn(*args)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1307, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1409, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input is not invertible.
         [[Node: model_3/1/28/invconv/MatrixInverse = MatrixInverse[T=DT_FLOAT, adjoint=false, _device="/job:localhost/replica:0/task:0/device:GPU:0"](model/1/28/invconv/W/read)]]
         [[Node: model_3/5/6/f1/l_1/Shape/_79621 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_9830_model_3/5/6/f1/l_1/Shape", tensor_type=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "train.py", line 414, in <module>
    main(hps)
  File "train.py", line 163, in main
    train(sess, model, hps, logdir, visualise)
  File "train.py", line 274, in train
    visualise(epoch)
  File "train.py", line 50, in draw_samples
    x_samples.append(sample_batch(y, [.0]*n_batch))
  File "train.py", line 33, in sample_batch
    y[i*n_batch:i*n_batch + n_batch], eps[i*n_batch:i*n_batch + n_batch]))
  File "/home/ubuntu/glow_/model.py", line 242, in sample
    return m.sess.run(x_sampled, {Y: _y, m.eps_std: _eps_std})
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 900, in run
    run_metadata_ptr)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1135, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
    run_metadata)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input is not invertible.
         [[Node: model_3/1/28/invconv/MatrixInverse = MatrixInverse[T=DT_FLOAT, adjoint=false, _device="/job:localhost/replica:0/task:0/device:GPU:0"](model/1/28/invconv/W/read)]]
         [[Node: model_3/5/6/f1/l_1/Shape/_79621 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_9830_model_3/5/6/f1/l_1/Shape", tensor_type=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'model_3/1/28/invconv/MatrixInverse', defined at:
  File "train.py", line 414, in <module>
    main(hps)
  File "train.py", line 156, in main
    model = model.model(sess, hps, train_iterator, test_iterator, data_init)
  File "/home/ubuntu/glow_/model.py", line 239, in model
    x_sampled = f_sample(Y, m.eps_std)
  File "/home/ubuntu/glow_/model.py", line 232, in f_sample
    z = decoder(z, eps_std=eps_std)
  File "/home/ubuntu/glow_/model.py", line 97, in decoder
    z, _ = revnet2d(str(i), z, 0, hps, reverse=True)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 183, in func_with_args
    return func(*args, **current_args)
  File "/home/ubuntu/glow_/model.py", line 342, in revnet2d
    z, logdet = revnet2d_step(str(i), z, logdet, hps, reverse)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 183, in func_with_args
    return func(*args, **current_args)
  File "/home/ubuntu/glow_/model.py", line 411, in revnet2d_step
    "invconv", z, logdet, reverse=True)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/contrib/framework/python/ops/arg_scope.py", line 183, in func_with_args
    return func(*args, **current_args)
  File "/home/ubuntu/glow_/model.py", line 467, in invertible_1x1_conv
    _w = tf.matrix_inverse(w)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/ops/gen_linalg_ops.py", line 1049, in matrix_inverse
    "MatrixInverse", input=input, adjoint=adjoint, name=name)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1718, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Input is not invertible.
         [[Node: model_3/1/28/invconv/MatrixInverse = MatrixInverse[T=DT_FLOAT, adjoint=false, _device="/job:localhost/replica:0/task:0/device:GPU:0"](model/1/28/invconv/W/read)]]
         [[Node: model_3/5/6/f1/l_1/Shape/_79621 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_9830_model_3/5/6/f1/l_1/Shape", tensor_type=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

I suspected that a bad learning rate might be making the kernel non-invertible, so I tried lower learning rates, but that did not help.

@nuges01

nuges01 commented Aug 16, 2018

I'm having this issue as well. I've been able to train on a custom dataset without conditioning on class labels. However, if I set --ycond and --weight_y 0.01, the gradients shrink at first but eventually explode to inf, causing the error.

@arunpatro
Author

I first ran the experiment on an AWS p2.xlarge EC2 instance, which gave me these errors.
I later re-ran it on an NVIDIA Titan X, where it ran smoothly. It converges on 64x64 images but not on 256x256, with exactly the same hyperparameters.

I doubt it can be solved. It's prone to bad random initialisations that drive training to inf and NaN values in the search space.

@omidsakhi

How about adding "+ tf.eye(shape[3]) * 10e-4 " to this line:

https://github.com/openai/glow/blob/master/model.py#L451

? Does that make any difference?
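
A minimal sketch of what this suggestion does, written in plain Python rather than TensorFlow so it runs standalone. The helper names det and add_jitter and the example matrix are illustrative assumptions, not code from the repository. Note that 10e-4 is the same number as 1e-3.

```python
def det(m):
    """Determinant via Gaussian elimination with partial pivoting."""
    a = [row[:] for row in m]          # work on a copy
    n = len(a)
    d = 1.0
    for c in range(n):
        # pick the row with the largest entry in column c as the pivot
        p = max(range(c, n), key=lambda r: abs(a[r][c]))
        if a[p][c] == 0.0:
            return 0.0                 # whole column is zero: singular
        if p != c:
            a[c], a[p] = a[p], a[c]
            d = -d                     # a row swap flips the sign
        d *= a[c][c]
        for r in range(c + 1, n):
            f = a[r][c] / a[c][c]
            for k in range(c, n):
                a[r][k] -= f * a[c][k]
    return d

def add_jitter(w, eps=10e-4):
    """Return w + eps * I for a square matrix given as a list of rows."""
    n = len(w)
    return [[w[i][j] + (eps if i == j else 0.0) for j in range(n)]
            for i in range(n)]

w = [[1.0, 2.0], [2.0, 4.0]]           # rank 1, so not invertible
print(det(w))                          # 0.0
print(det(add_jitter(w)))              # about 0.005, no longer singular
```

Adding a small multiple of the identity moves an exactly singular matrix away from singularity. It cannot repair a weight matrix that already contains NaN, since NaN + eps is still NaN; the logs above show nan statistics just before the error, so that case may be the relevant one here.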

@tatsuhiko-inoue

I also experienced a similar error.
I avoided the error using the following modification.

diff --git a/model.py b/model.py
index b918ab0..68cb3fe 100644
--- a/model.py
+++ b/model.py
@@ -373,7 +373,7 @@ def revnet2d_step(name, z, logdet, hps, reverse):
                 h = f("f1", z1, hps.width, n_z)
                 shift = h[:, :, :, 0::2]
                 # scale = tf.exp(h[:, :, :, 1::2])
-                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
+                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.) + 1e-10
                 z2 += shift
                 z2 *= scale
                 logdet += tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
@@ -393,7 +393,7 @@ def revnet2d_step(name, z, logdet, hps, reverse):
                 h = f("f1", z1, hps.width, n_z)
                 shift = h[:, :, :, 0::2]
                 # scale = tf.exp(h[:, :, :, 1::2])
-                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
+                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.) + 1e-10
                 z2 /= scale
                 z2 -= shift
                 logdet -= tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
diff --git a/tfops.py b/tfops.py
index d978419..2e7c556 100644
--- a/tfops.py
+++ b/tfops.py
@@ -449,9 +449,9 @@ def gaussian_diag(mean, logsd):
     o.sample = mean + tf.exp(logsd) * o.eps
     o.sample2 = lambda eps: mean + tf.exp(logsd) * eps
     o.logps = lambda x: -0.5 * \
-        (np.log(2 * np.pi) + 2. * logsd + (x - mean) ** 2 / tf.exp(2. * logsd))
+        (np.log(2 * np.pi) + 2. * logsd + (x - mean) ** 2 / (tf.exp(2. * logsd) + 1e-10))
     o.logp = lambda x: flatten_sum(o.logps(x))
-    o.get_eps = lambda x: (x - mean) / tf.exp(logsd)
+    o.get_eps = lambda x: (x - mean) / (tf.exp(logsd) + 1e-10)
     return o
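
To illustrate what the + 1e-10 terms in this diff guard against, here is a small standalone sketch in plain Python (float64), not TensorFlow. The helpers sigmoid and log_scale are illustrative names, not functions from the repository. math.log raises on zero, so the sketch returns -inf explicitly to mimic what tf.log gives for a zero input.

```python
import math

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)                    # underflows to 0.0 for very negative x
    return e / (1.0 + e)

def log_scale(h, eps=0.0):
    """log(sigmoid(h + 2) + eps), with -inf where the argument is zero."""
    s = sigmoid(h + 2.0) + eps
    return math.log(s) if s > 0.0 else -math.inf

print(log_scale(-1000.0))              # -inf: the logdet term blows up
print(log_scale(-1000.0, eps=1e-10))   # bounded below by log(1e-10)
```

With the epsilon, the log term is bounded below by log(1e-10), roughly -23. In float32 the sigmoid underflows far earlier than in the float64 used here, so the unguarded case is easier to hit in practice.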

@nuges01

nuges01 commented Aug 23, 2018

@tatsuhiko-inoue, Thanks for the suggestion. It didn't work for me though. Those modifications are under the condition elif hps.flow_coupling == 1 (affine coupling). I'm following the set of parameters for Conditional qualitative results for which flow_coupling is 0 (additive). For example:
python train.py --problem cifar10 --image_size 32 --n_level 3 --depth 32 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5 --ycond --weight_y=0.01
I just ran it after changing --flow_coupling to 1, but it still results in the problem.

@nuges01

nuges01 commented Aug 23, 2018

@omidsakhi, that didn't solve it for me either, I'm afraid. Thanks.

@tatsuhiko-inoue

When I run Glow, the gradient of "logsd" in gaussian_diag() can become NaN.
This happens once "logsd" reaches 45.0 or more.

I was able to avoid the NaN gradient by computing the gradient of "x/exp(y)" as a single op, as follows.
However, the loss then became unstable.

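@tf.custom_gradient
def div_by_exp(x, y):
    # Forward pass: x / exp(y), with a small epsilon guarding the denominator.
    exp_y = tf.exp(y) + 1e-10
    ret = x / exp_y
    def _grad(dy):
        # dy is the upstream gradient; returns the gradients for inputs x and y.
        return dy/exp_y, dy*-ret
    return ret, _grad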

def gaussian_diag(mean, logsd):
        :
    o.logps = lambda x: -0.5 * (np.log(2 * np.pi) + 2. * logsd + div_by_exp((x - mean) ** 2, 2*logsd))
        :
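
The threshold of about 45 is consistent with float32 overflow: exp(2 * 45) = exp(90) is roughly 1.2e39, above the largest finite float32 (about 3.4e38). The sketch below is plain Python that models only that overflow, not float32 rounding. The function names are illustrative, and the exact op order in TensorFlow's autodiff may differ.

```python
import math

FLOAT32_MAX = (2 - 2**-23) * 2.0**127  # largest finite float32, about 3.4e38

def exp_f32(v):
    """exp(v), replaced by inf once it exceeds the float32 range."""
    r = math.exp(v)
    return r if r <= FLOAT32_MAX else math.inf

def grad_y_unfused(x, y):
    """d/dy of x / exp(y), chaining the division and exp gradients separately."""
    e = exp_f32(y)
    d_wrt_e = -x / e / e               # gradient of x / e with respect to e
    return d_wrt_e * e                 # chain rule through e = exp(y)

def grad_y_fused(x, y):
    """The same derivative written as -(x / exp(y)), as in div_by_exp."""
    return -(x / exp_f32(y))

logsd = 45.0
print(grad_y_unfused(1.0, 2 * logsd))  # nan, from 0 * inf
print(grad_y_fused(1.0, 2 * logsd))    # -0.0, finite
```

This only illustrates why fusing the gradient removes the NaN. It says nothing about the loss instability reported above.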

@paulchou0309

I ran into the same issue. How can I solve it? @tatsuhiko-inoue @nuges01 @arunpatro

@naturomics

naturomics commented Feb 10, 2019

Hello guys, I found a solution for this 'not invertible' problem. During training, the weights of the invertible 1x1 conv keep increasing to balance the log-determinant terms generated by the invertible 1x1 conv and the affine coupling layer/actnorm. This can be solved by adding a regularization term only for the weights of the invertible 1x1 conv. In practice I use L2 regularization. It's also worth mentioning that, after adding the regularization term, slightly more epochs are needed to converge to the same NLL.

I discussed it in our recent publication "Generative Model with Dynamic Linear Flow", which improves the performance of flow-based methods significantly and converges faster than Glow. Our code is here.
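
A rough sketch of the regularization idea described in this comment, in plain Python so it runs without TensorFlow. It is not the code from the linked repository. The function name, the coefficient, the placeholder loss value, and the second weight name are hypothetical; the "invconv" name follows the variable scope seen in the traceback (model/1/28/invconv/W).

```python
def invconv_l2_penalty(weights, coeff):
    """coeff * sum of squared entries, over invertible 1x1 conv matrices only."""
    total = 0.0
    for name, matrix in weights.items():
        if "/invconv/" in name:        # skip coupling-layer and actnorm weights
            total += sum(v * v for row in matrix for v in row)
    return coeff * total

weights = {
    "model/1/28/invconv/W": [[3.0, 0.0], [0.0, 4.0]],  # scope name from the traceback
    "model/1/28/f1/l_1/W": [[100.0]],                   # hypothetical non-invconv weight
}
nll = 2.5                              # placeholder value for the usual loss
loss = nll + invconv_l2_penalty(weights, coeff=0.01)
print(loss)
```

In the TensorFlow 1.x code this would presumably amount to summing something like tf.nn.l2_loss over the variables whose names contain invconv and adding the scaled result to the training loss.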

@AnastasisKratsios

  • tf.eye(shape[3]) * 10e-4

I think you mean "+ tf.eye(3) * 10e-4"

shape[3] is not defined.
