
Profiler does not Seem to Output Timesteps in xplane.pb - "No step marker observed and hence the step time is unknown" from Tensorboard #66410

Open
stellarpower opened this issue Apr 24, 2024 · 3 comments
Assignees
Labels
comp:apis (Highlevel API related issues) · comp:tensorboard (Tensorboard related issues) · TF 2.15 (For issues related to 2.15.x) · type:bug (Bug) · WIP

Comments

@stellarpower

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

source

TensorFlow version

2.15.0 (cuda120py39hb94c71b_3 from conda-forge)

Custom code

No

OS platform and distribution

Ubuntu Jammy in a Podman container

Mobile device

No response

Python version

3.9.18

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

12.4 (cuda-cupti 12.4.127 @ h59595ed_1 from conda-forge)

GPU model and memory

RTX 3090, 24GiB

Current behavior?

I am in the process of writing a custom loss function, and trying to profile it to see where resources are currently used.

I have installed newer CUDA drivers, the latest release of TensorFlow (2.15), the CUDA CUPTI libraries, and the other dependencies needed for the TensorBoard profiler plugin.

I can run an example with the profiler and get what looks like reasonable data. With my own code, I get a warning back from _pywrap_profiler.xspace_to_tools_data() that no timesteps are contained in the file, and thus, some of the useful profiling information is absent/unusable. I cut back the example and found that if I use the MSE loss, the profile is complete; if I change to my own loss function, the timesteps are no longer output.

Given that the message is coming back from the core profiler library, and is contained within the encoded protocol buffers, I believe that this is an issue with the profiler and the main library, rather than the tensorboard utility or the profiling plugin.

The loss function is reasonably complicated, so I suspected at first that the large trace files might be an issue. However, when reducing down to the toy example below, whilst there is still a difference in file size, the fixed overhead in the files brings the two much closer together:
[screenshot: comparison of the generated xplane.pb file sizes]
I previously had warnings that the profiler had dropped frames due to insufficient buffer space, but with this trivially small data size, those are gone.
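As far as I understand, the "step marker" TensorBoard is looking for is a per-step trace event that Keras normally emits itself during fit(). For reference, a minimal sketch of emitting one explicitly from a hand-rolled loop (trivial Dense model and a placeholder logdir, separate from the reproduction script below) would look something like this; I haven't checked yet whether emitting the marker manually changes anything with the custom loss:

import numpy
import tensorflow as tf

model = tf.keras.models.Sequential([ tf.keras.layers.Dense(1, input_shape = (1,)) ])
model.compile(optimizer = 'adam', loss = 'mse')

x = numpy.zeros((8, 1), dtype = 'float32')

tf.profiler.experimental.start("/tmp/manual-profile") # Placeholder logdir
for step in range(8):
    # step_num together with _r=1 is what the profiler tooling uses to derive per-step time.
    with tf.profiler.experimental.Trace('train', step_num = step, _r = 1):
        model.train_on_batch(x, x)
tf.profiler.experimental.stop()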

Standalone code to reproduce the issue

#!/usr/bin/env python

import os, sys, pprint, numpy
import tensorflow as tf


ScriptDirectory = os.path.dirname(os.path.realpath(__file__))

# https://github.com/stellarpower/Soft-DTW/tree/stellapower.AddFeature.BetterParallelisation
sys.path.insert(0, "/path/to/Soft-DTW")
from softdtwkeras.SDTWLoss import SDTWLoss
#############################################################


from tensorflow.keras import layers
from tensorflow import keras

from datetime import datetime


## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Epochs = 128


ProfilingTensorboardLogsDirectory = f"{ ScriptDirectory }/BugLogs/{ datetime.now().strftime('%Y-%m-%d_%H.%M.%S') }"

tboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir = ProfilingTensorboardLogsDirectory,
    histogram_freq = 1,
    
    profile_batch = (32, 96) # Should be Middle 50%
)



## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Very small toy data, in case the file size is an issue.
x_train = numpy.zeros((1, 1, 1))


model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(batch_input_shape=(1, 1, 1), name='layers_flatten'),
])

model.compile(
      optimizer = 'adam',
    # This runs okay:
      # loss      =  'mse',
    # This results in no timestep info in the xplane file.
      loss      = SDTWLoss(gamma=0.5, BatchSize = 1)
)


model.fit(
    x         = x_train, 
    y         = x_train, # << trivial identity function
    epochs    = Epochs,
    callbacks = [tboard_callback]
)
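
For completeness, the same capture can also be driven programmatically rather than through the TensorBoard callback, which should help rule out the callback's batch-windowing logic; a minimal sketch reusing the names from the script above (this profiles every epoch rather than just the middle window, so the files are larger):

# Alternative capture: wrap the fit call directly instead of using profile_batch.
tf.profiler.experimental.start(ProfilingTensorboardLogsDirectory)
model.fit(x = x_train, y = x_train, epochs = Epochs)
tf.profiler.experimental.stop()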

Relevant log output

2024-04-24 23:32:23.290829: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-24 23:32:23.290868: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-24 23:32:23.291486: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-24 23:32:23.296288: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.



2024-04-24 23:32:26.851127: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.
2024-04-24 23:32:26.851153: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.
2024-04-24 23:32:26.857074: I external/local_xla/xla/backends/profiler/gpu/cupti_tracer.cc:1883] Profiler found 1 GPUs
2024-04-24 23:32:26.861348: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.
2024-04-24 23:32:26.861419: I external/local_xla/xla/backends/profiler/gpu/cupti_tracer.cc:2017] CUPTI activity buffer flushed


2024-04-24 23:32:26.871326: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-24 23:32:26.903533: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-24 23:32:26.903741: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-24 23:32:26.906062: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-24 23:32:26.906217: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-24 23:32:26.906350: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-24 23:32:26.982483: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-24 23:32:26.982679: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-04-24 23:32:26.982827: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355


2024-04-24 23:32:26.982929: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22199 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:01:00.0, compute capability: 8.6


Epoch 1/128
1/1 [==============================] - 2s 2s/step - loss: 0.0000e+00
...
Epoch 31/128
1/1 [==============================] - 0s 6ms/step - loss: 0.0000e+00
Epoch 32/128
2024-04-24 23:32:30.249189: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.
2024-04-24 23:32:30.249218: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.
1/1 [==============================] - 0s 15ms/step - loss: 0.0000e+00
Epoch 33/128
1/1 [==============================] - 0s 7ms/step - loss: 0.0000e+00
...
Epoch 95/128
1/1 [==============================] - 0s 7ms/step - loss: 0.0000e+00
Epoch 96/128


2024-04-24 23:32:30.746769: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:70] Profiler session collecting data.
2024-04-24 23:32:30.747796: I external/local_xla/xla/backends/profiler/gpu/cupti_tracer.cc:2017] CUPTI activity buffer flushed
2024-04-24 23:32:30.792759: I external/local_xla/xla/backends/profiler/gpu/cupti_collector.cc:541]  GpuTracer has collected 4216 callback api events and 3979 activity events. 
2024-04-24 23:32:30.840791: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:131] Profiler session tear down.
2024-04-24 23:32:30.841246: I external/local_tsl/tsl/profiler/rpc/client/save_profile.cc:144] Collecting XSpace to repository: /BugLogs/2024-04-24_23.32.26/plugins/profile/2024_04_24_23_32_30/the-alchemist.xplane.pb


1/1 [==============================] - 0s 120ms/step - loss: 0.0000e+00
Epoch 97/128
...
@google-ml-butler google-ml-butler bot added the type:bug Bug label Apr 24, 2024
@tilakrayal tilakrayal added TF 2.15 For issues related to 2.15.x comp:tensorboard Tensorboard related issues comp:apis Highlevel API related issues labels Apr 25, 2024
@tilakrayal
Contributor

@stellarpower
Thank you for reporting the issue. We are also trying to replicate it in our environment. Please allow us some time to reproduce it, and we will get back to you. Thank you!

@tilakrayal tilakrayal added the WIP label Apr 26, 2024
@stellarpower
Author

stellarpower commented Apr 26, 2024

Thanks! I can try to push a container image if it helps; let me know if that'd be useful. Or I guess I could try a Colab notebook too.

@stellarpower
Author

As a side note, is there an out-of-the-box way to create a Colab notebook that uses a TensorFlow nightly build? If not, given that reproducing on nightly is one of the main requests I see before people start pinging gists back and forth, I think a template would be useful to have for creating MREs. I assume I can pip install within the notebook, but having a template to work from would save some time on both sides.
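
For what it's worth, the cell I would expect to need at the top of a fresh runtime is something like the following (I haven't checked which extra packages the profiler plugin needs on Colab):

!pip install -q tf-nightly tb-nightly
import tensorflow as tf
print(tf.__version__)  # should report a nightly dev build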
