How to fine-tune on TPU v3-8 nodes? It runs without error but does not seem to progress. #38

Open
eurka opened this issue Oct 5, 2019 · 5 comments

eurka commented Oct 5, 2019

Hi!

Thanks for the great paper and for providing the code and model. I am trying to fine-tune the model on a TPU v3-8 node in Google Cloud. I made the following changes:

  • I added optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) to training.py (a sketch of this patch follows the list).
  • I patched keras.py and then set use_tpu=True and batch_size=8.
  • I set num_cores_per_replica=8 and iterations_per_loop=1, and added cluster=tf.contrib.cluster_resolver.TPUClusterResolver() in the call to tf.contrib.tpu.RunConfig. This should partition the model across the 8 cores of the TPU; with lower values of num_cores_per_replica I get an out-of-memory error. This is the exact code:
    run_config = tf.contrib.tpu.RunConfig(
            cluster=tf.contrib.cluster_resolver.TPUClusterResolver(),
            model_dir=args.model_dir,
            session_config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True),
            tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1, num_cores_per_replica=8,
                                                per_host_input_for_training=3))
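For reference, a minimal sketch of the first patch (the optimizer wrapper), assuming training.py builds a plain Adagrad optimizer as the ResourceApplyAdagrad ops in the logs below suggest; learning_rate and loss are placeholders, not the script's real variable names:

    # Sketch of the CrossShardOptimizer patch in training.py.
    # learning_rate and loss are placeholders (not necessarily the script's names).
    optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate)
    # Aggregate gradients across the TPU shards before applying them.
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())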

With these changes I can get training.py to run with the seq256_v1 model without error. However, it doesn't seem to do anything after the model has been compiled, initialized from the checkpoint, and the batches start being fed to the TPU. Even with a batch_size of only 8 and a total of 256 TFRecords in the input file, it never completes. The output I get is:

...
WARNING:tensorflow:Entity <bound method EncoderLayer.call of <transformer.EncoderLayer object at 0x7f95b350b110>> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, export AUTOGRAPH_VERBOSITY=10) and attach the full output. Cause: converting <bound method EncoderLayer.call of <transformer.EncoderLayer object at 0x7f95b350b110>>: AssertionError: Bad argument number for Name: 3, expecting 4
WARNING:tensorflow:Entity <bound method MultiHeadAttention.call of <transformer.MultiHeadAttention object at 0x7f95b350b150>> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, export AUTOGRAPH_VERBOSITY=10) and attach the full output. Cause: converting <bound method MultiHeadAttention.call of <transformer.MultiHeadAttention object at 0x7f95b350b150>>: AttributeError: 'module' object has no attribute 'Num'
...
INFO:tensorflow:Starting infeed thread controller.
INFO:tensorflow:Starting outfeed thread controller.
WARNING:tensorflow:TPUPollingThread found TPU tpuk in state READY, and health HEALTHY.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
WARNING:tensorflow:TPUPollingThread found TPU tpuk in state READY, and health HEALTHY.
WARNING:tensorflow:TPUPollingThread found TPU tpuk in state READY, and health HEALTHY.
WARNING:tensorflow:TPUPollingThread found TPU tpuk in state READY, and health HEALTHY.
WARNING:tensorflow:TPUPollingThread found TPU tpuk in state READY, and health HEALTHY.
WARNING:tensorflow:TPUPollingThread found TPU tpuk in state READY, and health HEALTHY.
...

The last WARNING line keeps repeating.

With TensorBoard I wasn't able to capture a trace, which may indicate that nothing is happening on the TPU.

By a simple calculation based on the numbers presented in the paper, I should be able to get roughly 1024 (examples/batch) * 800,000 (iterations) / 32 (= 256/8, the ratio of cores in the TPU v3-256 Pod used in the paper to cores in a TPU v3-8 node) / 14 (days) / 24 (hours/day) / 3600 (seconds/hour) ≈ 20 examples per second.
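The same estimate, spelled out in Python for clarity (a sketch; all numbers come from the calculation above):

    # Back-of-the-envelope throughput expected on a v3-8, scaling the
    # paper's v3-256 Pod run (1024 x 800k examples in 14 days) by the core ratio.
    examples_per_batch = 1024
    iterations = 800000
    pod_seconds = 14 * 24 * 3600          # 14 days of wall-clock time
    core_ratio = 256 / 8.0                # Pod cores / v3-8 cores
    expected = examples_per_batch * iterations / core_ratio / pod_seconds
    print(round(expected, 1))             # prints ~21, i.e. roughly 20 examples/second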

I have been able to run other (much smaller) Keras models in TF 1.14 on a TPU v3-8 using the same RunConfig, where I also parallelized the model across the 8 TPU cores.

Do you have any idea why the training does not seem to work (or at best is extremely slow)? Am I parallelizing the model across the 8 TPU cores in the correct way? How was this done for the paper?

Any help would be greatly appreciated!

Many thanks,
Kees

PS: I get the same result when I add input_partition_dims=[[1, 1], [1, 1]] as an option to the TPUConfig.

@keskarnitish
Contributor

Are you using GKE or GCE?
Also, are you fine-tuning the 256 model or the 512 one?

Can you try the 256 version with TPU software version tf-version.cloud-tpus.google.com: "1.14.1.dev20190518" and the following config:

...
resolver = tf.contrib.cluster_resolver.TPUClusterResolver()
...
run_config = tf.contrib.tpu.RunConfig(
        cluster=resolver,
        model_dir='gs://PATH_TO_MODEL_FILE',
        session_config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True),
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1000, num_cores_per_replica=1, input_partition_dims=[[1, 1], [1, 1]], per_host_input_for_training=3))

and in the patched file:

estimator = tf.contrib.tpu.TPUEstimator(keras_model_fn, use_tpu=True, train_batch_size=512, eval_batch_size=32)
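For completeness, a sketch of how the patched estimator is presumably wired to the run_config above and then trained; passing config=run_config, and the input_fn and args.iterations names, are assumptions based on the training.py traceback later in this thread:

    # Assumption: the patched keras.py hands the RunConfig above to the estimator via config=...
    estimator = tf.contrib.tpu.TPUEstimator(keras_model_fn, config=run_config, use_tpu=True,
                                            train_batch_size=512, eval_batch_size=32)

    # training.py then drives the loop (these names appear in the traceback below).
    estimator.train(input_fn=input_fn, steps=args.iterations)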

eurka commented Oct 17, 2019

I am using GCE with the 256 version. Previously I used a TF 1.14 VM and a TF 1.14 v3-8 TPU node. Now I tried again with a TF 1.14 VM and a TF 1.14.1.dev20190518 TPU v3-8 node (the version was confirmed in the console) and the settings you provided, but I still get an OOM error.

@keskarnitish
Contributor

Are you able to train a smaller model (pointing it to an empty model directory instead of the 48-layer checkpoint)?

@pgrandinetti

I ran this and got the following error.
My code is here: https://github.com/pgrandinetti/ctrl/tree/master/training_utils (note there's one more line than you mentioned).

Related to #65

Instructions for updating:
Use standard file utilities to get mtimes.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 413000 into gs://ctrl-tuning/sf-ctrl/seqlen256_v1.ckpt/model.ckpt.
WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py:741: load (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Prefer Variable.assign which has equivalent behavior in 2.X.
INFO:tensorflow:Initialized dataset iterators in 3 seconds
INFO:tensorflow:Installing graceful shutdown hook.
2019-12-03 17:23:52.358863: W tensorflow/core/distributed_runtime/rpc/grpc_session.cc:356] GrpcSession::ListDevices will initialize the session with an empty graph and other defaults because the session has not yet been created.
INFO:tensorflow:Creating heartbeat manager for ['/job:worker/replica:0/task:0/device:CPU:0']
INFO:tensorflow:Configuring worker heartbeat: shutdown_mode: WAIT_FOR_COORDINATOR

INFO:tensorflow:Initialized TPU in 2 seconds
INFO:tensorflow:Starting infeed thread controller.
INFO:tensorflow:Starting outfeed thread controller.
WARNING:tensorflow:TPUPollingThread found TPU ctrl-tuning in state READY, and health HEALTHY.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
WARNING:tensorflow:TPUPollingThread found TPU ctrl-tuning in state READY, and health HEALTHY.
WARNING:tensorflow:TPUPollingThread found TPU ctrl-tuning in state READY, and health HEALTHY.
WARNING:tensorflow:TPUPollingThread found TPU ctrl-tuning in state READY, and health HEALTHY.
WARNING:tensorflow:TPUPollingThread found TPU ctrl-tuning in state READY, and health HEALTHY.
WARNING:tensorflow:TPUPollingThread found TPU ctrl-tuning in state READY, and health HEALTHY.
ERROR:tensorflow:Error recorded from outfeed: From /job:worker/replica:0/task:0:
Bad hardware status: 0x1
         [[node OutfeedDequeueTuple (defined at training.py:171) ]]

Original stack trace for u'OutfeedDequeueTuple':
  File "training.py", line 171, in <module>
    estimator_model.train(input_fn=input_fn, steps=args.iterations)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 2871, in train
    saving_listeners=saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 367, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1158, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1188, in _train_model_default
    features, labels, ModeKeys.TRAIN, self.config)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 2709, in _call_model_fn
    config)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1146, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 3023, in _model_fn
    host_ops = host_call.create_tpu_hostcall()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 2031, in create_tpu_hostcall
    device_ordinal=ordinal_id)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_tpu_ops.py", line 3190, in outfeed_dequeue_tuple
    device_ordinal=device_ordinal, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()

ERROR:tensorflow:Error recorded from training_loop: From /job:worker/replica:0/task:0:
Compilation failure: Ran out of memory in memory space hbm. Used 16.52G of 16.00G hbm. Exceeded hbm capacity by 530.53M.

Total hbm usage >= 16.52G:
    reserved        528.00M
    program           3.79G
    arguments        12.21G (100.0% utilization)

Output size 12.21G (100.0% utilization); shares 12.21G with arguments.

Program hbm requirement 3.79G:
    reserved          12.0K
    global            1.83M
    scoped           14.14M
    HLO temp          3.77G (100.0% utilization, 0.4% fragmentation (14.88M))

  Largest program allocations in hbm:

  1. Size: 1.18G
     Shape: f32[246790,1280]{1,0:T(8,128)}
     Unpadded size: 1.18G
     Extra memory due to padding: 10.0K (1.0x expansion)
     XLA label: %fusion.1736 = f32[246790,1280]{1,0:T(8,128)} fusion(f32[256,1280]{1,0:T(8,128)} %fusion.3870, f32[246534,1280]{1,0:T(8,128)} %fusion.2182.remat), kind=kLoop, calls=%fused_computation.578
     Allocation type: HLO temp
     ==========================

  2. Size: 1.17G
     Operator: op_type="Transpose" op_name="training/gradients/tied_embedding_softmax_1/transpose_grad/transpose"
     Shape: f32[246534,1280]{1,0:T(8,128)}
     Unpadded size: 1.17G
     Extra memory due to padding: 10.0K (1.0x expansion)
     XLA label: %fusion.2182.remat = f32[246534,1280]{1,0:T(8,128)} fusion(f32[256]{0:T(256)} %fusion.2085, f32[256]{0:T(256)} %fusion.12448.remat4, u32[256]{0:T(256)} %fusion.12464.remat2, pred[256]{0:T(256)E(32)} %get-tuple-element.44117, f32[256,246534]{0,1:T(8,128)} %...
     Allocation type: HLO temp
     ==========================

  3. Size: 6.06M
     Shape: (s32[]{:T(256)}, f32[]{:T(256)}, s32[]{:T(256)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[8192]{0:T(1024)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[8192]{0:T(1024)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[8192]{0:T(1024)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, MORE STUFF HERE
    Unpadded size: 6.06M
     XLA label: %tuple.5919 = (s32[]{:T(256)}, f32[]{:T(256)}, s32[]{:T(256)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)...
     Allocation type: HLO temp
     ==========================

  4. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %fusion.1320 = f32[1280,8,128]{2,1,0:T(8,128)} fusion(f32[640,8,128]{2,1,0:T(8,128)} %dynamic-slice.2701.remat2, f32[640,8,128]{2,1,0:T(8,128)} %dynamic-slice.2700), kind=kLoop, calls=%fused_computation.162
     Allocation type: HLO temp
     ==========================

  5. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %wide_param.3 = (s32[]{:T(256)}, f32[]{:T(256)}, s32[]{:T(256)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,12...
     Allocation type: HLO temp
     ==========================

  6. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %while.3 = (s32[]{:T(256)}, f32[]{:T(256)}, s32[]{:T(256)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, ...
     Allocation type: HLO temp
     ==========================

  7. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %fusion.1298 = f32[1280,8,128]{2,1,0:T(8,128)} fusion(f32[640,8,128]{2,1,0:T(8,128)} %dynamic-slice.2569.remat2, f32[640,8,128]{2,1,0:T(8,128)} %dynamic-slice.2568), kind=kLoop, calls=%fused_computation.140
     Allocation type: HLO temp
     ==========================

  8. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %while.3 = (s32[]{:T(256)}, f32[]{:T(256)}, s32[]{:T(256)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, ...
     Allocation type: HLO temp
     ==========================

  9. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %while.3 = (s32[]{:T(256)}, f32[]{:T(256)}, s32[]{:T(256)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, ...
     Allocation type: HLO temp
     ==========================

  10. Size: 5.00M
     Operator: op_type="ResourceApplyAdagrad" op_name="training/Adagrad/update_encoder/encoder_layer_27/sequential_27/dense_167/kernel/ResourceApplyAdagrad"
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %fusion.2900 = (f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,...
     Allocation type: HLO temp
     ==========================

  11. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %wide_param.3 = (s32[]{:T(256)}, f32[]{:T(256)}, s32[]{:T(256)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,12...
     Allocation type: HLO temp
     ==========================

  12. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %while.3 = (s32[]{:T(256)}, f32[]{:T(256)}, s32[]{:T(256)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, ...
     Allocation type: HLO temp
     ==========================

  13. Size: 5.00M
     Operator: op_type="ResourceApplyAdagrad" op_name="training/Adagrad/update_encoder/encoder_layer_23/sequential_23/dense_143/kernel/ResourceApplyAdagrad"
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %fusion.2892 = (f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,...
     Allocation type: HLO temp
     ==========================

  14. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %wide_param.3 = (s32[]{:T(256)}, f32[]{:T(256)}, s32[]{:T(256)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,12...
     Allocation type: HLO temp
     ==========================

  15. Size: 5.00M
     Operator: op_type="ResourceApplyAdagrad" op_name="training/Adagrad/update_encoder/encoder_layer_35/sequential_35/dense_215/kernel/ResourceApplyAdagrad"
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %fusion.2918 = (f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,...
     Allocation type: HLO temp
     ==========================

  16. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %fusion.1272 = f32[1280,8,128]{2,1,0:T(8,128)} fusion(f32[640,8,128]{2,1,0:T(8,128)} %dynamic-slice.2413.remat2, f32[640,8,128]{2,1,0:T(8,128)} %dynamic-slice.2412), kind=kLoop, calls=%fused_computation.114
     Allocation type: HLO temp
     ==========================

  17. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %fusion.1321 = f32[1280,8,128]{2,1,0:T(8,128)} fusion(f32[640,8,128]{2,1,0:T(8,128)} %dynamic-slice.2703.remat2, f32[640,8,128]{2,1,0:T(8,128)} %dynamic-slice.2702), kind=kLoop, calls=%fused_computation.163
     Allocation type: HLO temp
     ==========================

  18. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %wide_param.2 = (s32[]{:T(256)}, f32[]{:T(256)}, s32[]{:T(256)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,12...
     Allocation type: HLO temp
     ==========================

  19. Size: 5.00M
     Operator: op_type="ResourceApplyAdagrad" op_name="training/Adagrad/update_encoder/encoder_layer_35/sequential_35/dense_215/kernel/ResourceApplyAdagrad"
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %fusion.2918 = (f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,8,128]{2,1,0:T(8,128)}, f32[1280,...
     Allocation type: HLO temp
     ==========================

  20. Size: 5.00M
     Shape: f32[1280,8,128]{2,1,0:T(8,128)}
     Unpadded size: 5.00M
     XLA label: %wide_param.3 = (s32[]{:T(256)}, f32[]{:T(256)}, s32[]{:T(256)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,128)}, f32[1280]{0:T(1024)}, f32[200,8,128]{2,1,0:T(8,12...
     Allocation type: HLO temp
     ==========================
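For context, a rough back-of-the-envelope check of the 12.21G "arguments" figure in the log above, assuming the 1.63-billion-parameter CTRL model in float32 with one Adagrad accumulator per weight and no model parallelism (everything resident on a single core):

    # Rough estimate of per-core HBM needed just for the weights plus the
    # Adagrad accumulators when nothing is partitioned across cores.
    params = 1.63e9                      # CTRL has ~1.63 billion parameters (per the paper)
    bytes_per_value = 4                  # float32
    copies = 2                           # weights + one Adagrad accumulator per weight
    gib = params * bytes_per_value * copies / 2**30
    print(round(gib, 2))                 # ~12.1 GiB, close to the 12.21G "arguments" reported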

@ofrikeywee

+1. I get the same output as @pgrandinetti above.
