[BUG] UMAP spectral initialization fails to preserve global structure. #5782

Open
kc-howe opened this issue Feb 23, 2024 · 1 comment
Labels: ? - Needs Triage (Need team to review and classify), bug (Something isn't working)

kc-howe commented Feb 23, 2024

Describe the bug

UMAP spectral initialization yields unexpected initial layouts. Consequently, the global structure of the input data is often not preserved, even when a very high n_epochs value is used. If UMAP is used as a pre-processing step for clustering, this behavior can affect results significantly, depending on the geometry of the input data.

Steps/Code to reproduce bug

The code below is a minimal example that exhibits the issue on scikit-learn's make_circles dataset. Because the results of the cuML UMAP implementation may vary even when random_state is set, the code may need to be run several times before it produces a result that demonstrates the problematic behavior.

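import os

import cudf
import cupy
import plotly.express as px
from cuml import UMAP as GPU_UMAP
from plotly.subplots import make_subplots
from sklearn.datasets import make_circles
from umap import UMAP as CPU_UMAP

IMAGES_DIRECTORY = './images'
# Make sure the output directory exists before the figure is written at the end
os.makedirs(IMAGES_DIRECTORY, exist_ok=True)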

# Create datasets
X_circles, y_circles = make_circles(1000, random_state=42)

# CPU embedding
cpu_mapper = CPU_UMAP(
    n_neighbors=15,
    n_components=2,
    n_epochs=2000,
    init='spectral',
    learning_rate=1.0,
    random_state=0
)
cpu_embedding = cpu_mapper.fit_transform(X_circles)
cpu_embedding_df = cudf.DataFrame(cpu_embedding, columns=['x','y'])
cpu_embedding_df['label'] = y_circles.astype(str)

# GPU embedding
gpu_mapper = GPU_UMAP(
    n_neighbors=15,
    n_components=2,
    n_epochs=2000,
    init='spectral',
    learning_rate=1.0,
    random_state=0
)
gpu_embedding = gpu_mapper.fit_transform(cupy.asarray(X_circles))
gpu_embedding_df = cudf.DataFrame(gpu_embedding, columns=['x','y'])
gpu_embedding_df['label'] = y_circles.astype(str)

# Make plots
cpu_plot = px.scatter(
    cpu_embedding_df,
    x='x',
    y='y',
    color='label',
    title='Circles UMAP Embedding (CPU)'
)

gpu_plot = px.scatter(
    gpu_embedding_df,
    x='x',
    y='y',
    color='label',
    title='Circles UMAP Embedding (GPU)'
)

fig = make_subplots(1, 2, subplot_titles=['CPU Embedding', 'GPU Embedding'])
for trace in cpu_plot['data']:
    fig.append_trace(trace, row=1, col=1)
for trace in gpu_plot['data']:
    fig.append_trace(trace, row=1, col=2)

fig.update_layout(
    title_text='UMAP Spectral Init Global Structure Comparison',
    width=900,
    height=500,
    showlegend=False
)

fig.write_image(f'{IMAGES_DIRECTORY}/circles_embedding.png')

The resulting plot should look something like this:

circles_embedding

Expected behavior

Spectral initialization should yield global structure similar to that of the CPU version of UMAP in the initial layout, and consequently after layout optimization. For the most part, at least, spatially distinct connected components of the input data should stay separated after embedding; UMAP with spectral initialization should not make clustering such cases more difficult (see the rings example below).

The results of CPU and GPU UMAP need not be identical, of course, since there are understandably some implementation differences, particularly with regard to spectral embedding. However, the general behavior with respect to global structure preservation under spectral initialization should be the same.
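One way to make "similar global structure" comparable across runs, rather than judging plots by eye, is a simple neighbor label purity score on each embedding. The sketch below is only an illustration: the function name neighbor_label_purity and the choice of k are assumptions, not part of cuML or umap-learn, and the demo data are synthetic stand-ins for real embeddings.

```python
import numpy as np


def neighbor_label_purity(embedding, labels, k=10):
    """Mean fraction of each point's k nearest embedding neighbors sharing its label.

    Hypothetical helper for comparing embeddings; brute force, so intended
    for a few thousand points at most.
    """
    embedding = np.asarray(embedding, dtype=float)
    labels = np.asarray(labels)
    # Pairwise squared Euclidean distances between embedded points.
    diff = embedding[:, None, :] - embedding[None, :, :]
    d2 = np.einsum('ijk,ijk->ij', diff, diff)
    np.fill_diagonal(d2, np.inf)  # a point is not its own neighbor
    # Indices of the k smallest distances in each row (order within the k is arbitrary).
    nn = np.argpartition(d2, k, axis=1)[:, :k]
    return float((labels[nn] == labels[:, None]).mean())


# Synthetic stand-ins for embeddings: two well-separated groups vs. fully mixed groups.
rng = np.random.default_rng(0)
labels = np.repeat([0, 1], 200)
separated = rng.normal(size=(400, 2)) + 20.0 * labels[:, None]
mixed = rng.normal(size=(400, 2))

print('separated:', neighbor_label_purity(separated, labels))
print('mixed:', neighbor_label_purity(mixed, labels))
```

Applied to the reproduction code above, the same function could be called on cpu_embedding and on gpu_embedding (converted to a NumPy array) with y_circles as labels. A score near 1 means the two circles stay apart in the embedding, while a score near 0.5 means they are interleaved.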

The following figure demonstrates the extreme difference in spectral initialization behavior between CPU and GPU UMAP on scikit-learn's make_blobs. Note that it seems impossible to set n_epochs to exactly 0 in the cuML implementation without invoking default values, so a minimal value of 1 is used below. (Apologies for the small text; CPU results are in the top row and GPU results are in the bottom row.)

umap_embedding_cpu_gpu_blobs_epochs1_init=spectral

For a dataset like make_blobs, the initialization is very different, but the CPU and GPU results can generally be brought into agreement by running a high number of epochs. Here is a similar figure demonstrating the difference on a dataset composed of three non-concentric rings, which cannot be improved by setting a high n_epochs (see further below):

three_rings_dataset

umap_embedding_cpu_gpu_equidistant_rings_epochs=1_init=spectral

I do not suspect that the difference in spectral initialization results can always be resolved by raising the n_epochs parameter, as has been suggested as a potential resolution to similar issues reported by other users (e.g. #5474). Since the low-dimensional layout optimization only acts on KNN-localized edge weights, I don't expect any number of epochs to guarantee the recovery of global structure. To be sure, we can check that at n_epochs values of 500 and 2000 we observe the same difference in behavior between CPU and GPU UMAP for the three rings dataset above.

umap_embedding_cpu_gpu_equidistant_rings_epochs=500_init=spectral

umap_embedding_cpu_gpu_equidistant_rings_epochs=2000_init=spectral
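The locality argument above can be illustrated with a small NumPy sketch. The exact geometry of the three rings dataset is not given in this report, so the centers, radius, and sampling below are assumptions chosen only to produce three equidistant, non-overlapping rings; make_rings, knn_edges, and count_components are hypothetical helper names. The sketch builds a brute-force k-nearest-neighbor graph and counts its connected components and cross-ring edges.

```python
import numpy as np


def make_rings(centers, radius=1.0, n_per_ring=200):
    """Evenly spaced points on one circle per center; returns (X, ring_index)."""
    theta = np.linspace(0.0, 2.0 * np.pi, n_per_ring, endpoint=False)
    circle = radius * np.column_stack([np.cos(theta), np.sin(theta)])
    X = np.vstack([circle + np.asarray(c, dtype=float) for c in centers])
    y = np.repeat(np.arange(len(centers)), n_per_ring)
    return X, y


def knn_edges(X, k):
    """Directed edge list (rows -> cols) of the brute-force k-nearest-neighbor graph."""
    d2 = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(d2, np.inf)  # exclude self-neighbors
    nn = np.argpartition(d2, k, axis=1)[:, :k]
    rows = np.repeat(np.arange(len(X)), k)
    return rows, nn.ravel()


def count_components(n, rows, cols):
    """Count connected components, treating edges as undirected (union-find)."""
    parent = list(range(n))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    for a, b in zip(rows.tolist(), cols.tolist()):
        parent[find(a)] = find(b)
    return len({find(i) for i in range(n)})


# Assumed geometry: unit rings centered on an equilateral triangle with side 4.
centers = [(0.0, 0.0), (4.0, 0.0), (2.0, 2.0 * np.sqrt(3.0))]
X, y = make_rings(centers)
rows, cols = knn_edges(X, k=15)

n_components = count_components(len(X), rows, cols)
cross_ring_edges = int((y[rows] != y[cols]).sum())
print('components:', n_components, 'cross-ring edges:', cross_ring_edges)
```

When no neighbor edge joins two rings, the attractive part of the layout optimization has nothing that ties their relative positions together. In that case the relative placement of the rings comes largely from the initialization, which is why extra epochs would not be expected to repair it.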

Environment details (please complete the following information):

  • Environment location: Bare metal
  • Linux Distro/Architecture: Pop!_OS 22.04 LTS x86_64
  • GPU Model/Driver: NVIDIA GeForce RTX 4070 / Driver 545
  • CUDA: 12.3
  • Method of cuDF & cuML install: conda (environment details below)
# packages in environment at /home/kenneth/miniconda3:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                       1_gnu  
archspec                  0.2.1              pyhd3eb1b0_0  
asttokens                 2.0.5              pyhd3eb1b0_0  
boltons                   23.0.0          py310h06a4308_0  
bzip2                     1.0.8                h7b6447c_0  
c-ares                    1.19.1               h5eee18b_0  
ca-certificates           2023.12.12           h06a4308_0  
caerus                    0.1.9                    pypi_0    pypi
certifi                   2024.2.2        py310h06a4308_0  
cffi                      1.16.0          py310h5eee18b_0  
charset-normalizer        2.0.4              pyhd3eb1b0_0  
comm                      0.1.2           py310h06a4308_0  
conda                     24.1.1          py310h06a4308_0  
conda-content-trust       0.1.3           py310h06a4308_0  
conda-libmamba-solver     24.1.0             pyhd3eb1b0_0  
conda-package-handling    2.2.0           py310h06a4308_0  
conda-package-streaming   0.9.0           py310h06a4308_0  
contourpy                 1.1.1                    pypi_0    pypi
cryptography              39.0.1          py310h9ce1e76_2  
cycler                    0.12.1                   pypi_0    pypi
cython                    0.29.36                  pypi_0    pypi
debugpy                   1.6.7           py310h6a678d5_0  
decorator                 5.1.1              pyhd3eb1b0_0  
distro                    1.8.0           py310h06a4308_0  
dtwalign                  0.1.1                    pypi_0    pypi
exceptiongroup            1.2.0           py310h06a4308_0  
executing                 0.8.3              pyhd3eb1b0_0  
findpeaks                 2.5.4                    pypi_0    pypi
fmt                       9.1.0                hdb19cb5_0  
fonttools                 4.43.1                   pypi_0    pypi
hdbscan                   0.8.33                   pypi_0    pypi
icu                       73.1                 h6a678d5_0  
idna                      3.4             py310h06a4308_0  
iniconfig                 2.0.0                    pypi_0    pypi
ipykernel                 6.25.2             pyh2140261_0    conda-forge
ipython                   8.20.0          py310h06a4308_0  
jedi                      0.18.1          py310h06a4308_1  
joblib                    1.3.2                    pypi_0    pypi
jsonpatch                 1.32               pyhd3eb1b0_0  
jsonpointer               2.1                pyhd3eb1b0_0  
jupyter_client            8.6.0           py310h06a4308_0  
jupyter_core              5.5.0           py310h06a4308_0  
kiwisolver                1.4.5                    pypi_0    pypi
krb5                      1.20.1               h143b758_1  
ld_impl_linux-64          2.38                 h1181459_1  
libarchive                3.6.2                h6ac8c49_2  
libcurl                   8.5.0                h251f7ec_0  
libedit                   3.1.20230828         h5eee18b_0  
libev                     4.33                 h7f8727e_1  
libffi                    3.4.4                h6a678d5_0  
libgcc-ng                 11.2.0               h1234567_1  
libgomp                   11.2.0               h1234567_1  
libmamba                  1.5.6                haf1ee3a_0  
libmambapy                1.5.6           py310h2dafd23_0  
libnghttp2                1.57.0               h2d74bed_0  
libsodium                 1.0.18               h7b6447c_0  
libsolv                   0.7.24               he621ea3_0  
libssh2                   1.10.0               hdbd6064_2  
libstdcxx-ng              11.2.0               h1234567_1  
libuuid                   1.41.5               h5eee18b_0  
libxml2                   2.10.4               hf1b16e4_1  
llvmlite                  0.41.0                   pypi_0    pypi
lz4-c                     1.9.4                h6a678d5_0  
matplotlib                3.8.0                    pypi_0    pypi
matplotlib-inline         0.1.6           py310h06a4308_0  
menuinst                  2.0.2           py310h06a4308_0  
ncurses                   6.4                  h6a678d5_0  
nest-asyncio              1.5.6           py310h06a4308_0  
networkx                  3.1                      pypi_0    pypi
numba                     0.58.0                   pypi_0    pypi
numpy                     1.25.2                   pypi_0    pypi
openssl                   3.0.13               h7f8727e_0  
packaging                 23.1            py310h06a4308_0  
pandas                    2.1.1                    pypi_0    pypi
parso                     0.8.3              pyhd3eb1b0_0  
pcre2                     10.42                hebb0a14_0  
peakdetect                1.1                      pypi_0    pypi
pexpect                   4.8.0              pyhd3eb1b0_3  
pillow                    10.0.1                   pypi_0    pypi
pip                       22.3.1          py310h06a4308_0  
platformdirs              3.10.0          py310h06a4308_0  
plotly                    5.17.0                   pypi_0    pypi
pluggy                    1.4.0                    pypi_0    pypi
prompt-toolkit            3.0.43          py310h06a4308_0  
prompt_toolkit            3.0.43               hd3eb1b0_0  
psutil                    5.9.0           py310h5eee18b_0  
ptyprocess                0.7.0              pyhd3eb1b0_2  
pure_eval                 0.2.2              pyhd3eb1b0_0  
pybind11-abi              4                    hd3eb1b0_1  
pycosat                   0.6.6           py310h5eee18b_0  
pycparser                 2.21               pyhd3eb1b0_0  
pygments                  2.15.1          py310h06a4308_1  
pygraphviz                1.11                     pypi_0    pypi
pynndescent               0.5.10                   pypi_0    pypi
pyparsing                 3.1.1                    pypi_0    pypi
pytest                    8.0.0                    pypi_0    pypi
python                    3.10.13              h955ad1f_0  
python-dateutil           2.8.2              pyhd3eb1b0_0  
pytz                      2023.3.post1             pypi_0    pypi
pyzmq                     25.1.2          py310h6a678d5_0  
readline                  8.2                  h5eee18b_0  
reproc                    14.2.4               h295c915_1  
reproc-cpp                14.2.4               h295c915_1  
requests                  2.31.0          py310h06a4308_1  
ruamel.yaml               0.17.21         py310h5eee18b_0  
ruamel.yaml.clib          0.2.6           py310h5eee18b_1  
scikit-learn              1.3.2                    pypi_0    pypi
scipy                     1.11.3                   pypi_0    pypi
seaborn                   0.13.0                   pypi_0    pypi
setuptools                65.6.3          py310h06a4308_0  
six                       1.16.0             pyhd3eb1b0_1  
sqlite                    3.41.2               h5eee18b_0  
stack_data                0.2.0              pyhd3eb1b0_0  
tbb                       2021.10.0                pypi_0    pypi
tenacity                  8.2.3                    pypi_0    pypi
threadpoolctl             3.2.0                    pypi_0    pypi
tk                        8.6.12               h1ccaba5_0  
tomli                     2.0.1                    pypi_0    pypi
tornado                   6.3.3           py310h5eee18b_0  
tqdm                      4.65.0          py310h2f386ee_0  
traitlets                 5.7.1           py310h06a4308_0  
truststore                0.8.0           py310h06a4308_0  
tzdata                    2023.3                   pypi_0    pypi
umap-learn                0.5.4                    pypi_0    pypi
urllib3                   2.1.0           py310h06a4308_0  
wcwidth                   0.2.5              pyhd3eb1b0_0  
wget                      3.2                      pypi_0    pypi
wheel                     0.37.1             pyhd3eb1b0_0  
xarray                    2023.9.0                 pypi_0    pypi
xz                        5.4.5                h5eee18b_0  
yaml-cpp                  0.8.0                h6a678d5_0  
zeromq                    4.3.5                h6a678d5_0  
zlib                      1.2.13               h5eee18b_0  
zstandard                 0.19.0          py310h5eee18b_0  
zstd                      1.5.5                hc292b87_0  

Additional context

I initially felt this could be related to the known issue with the Laplacian eigenmaps solver mentioned in the comments of #5474; however, the differences in results compared to the CPU solvers seem somewhat extreme.

I am also aware that CPU UMAP handles spectral layout of networks with multiple components somewhat differently than single-component networks. However, datasets that yield single-component graphs may exhibit the same behavior as above, e.g. when embedding a single ring.
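For the single ring case, a reference point may help: a textbook Laplacian eigenmaps layout of an evenly sampled ring is itself a ring. The sketch below checks this with plain NumPy and a dense eigensolver. It is an assumption-laden illustration (evenly spaced points, unweighted symmetrized KNN graph, symmetric normalized Laplacian) and is not the code path used by umap-learn or cuML.

```python
import numpy as np

n, k = 100, 4

# Evenly spaced points on the unit circle.
theta = 2.0 * np.pi * np.arange(n) / n
X = np.column_stack([np.cos(theta), np.sin(theta)])

# Brute-force k-nearest-neighbor adjacency (unweighted), then symmetrized.
d2 = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
np.fill_diagonal(d2, np.inf)
nn = np.argsort(d2, axis=1)[:, :k]
A = np.zeros((n, n))
A[np.repeat(np.arange(n), k), nn.ravel()] = 1.0
A = np.maximum(A, A.T)

# Symmetric normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}.
deg = A.sum(axis=1)
L = np.eye(n) - A / np.sqrt(np.outer(deg, deg))

# eigh returns eigenvalues in ascending order. The first eigenvector is the
# trivial one for a connected graph, so the 2-D layout uses the next two.
eigvals, eigvecs = np.linalg.eigh(L)
layout = eigvecs[:, 1:3]

# For a ring-shaped layout, every point sits at the same distance from the origin.
radii = np.linalg.norm(layout, axis=1)
print('relative spread of radii:', radii.std() / radii.mean())
```

A relative spread near zero means the spectral layout places the points on a circle. An initial layout of a single ring that departs strongly from this shape would point at the eigensolver rather than at the handling of multiple components.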

I am happy to provide any additional code, examples, or environment details upon request.

Additionally, thank you all for the incredible work you do on this repository, and in particular for bringing UMAP to the GPU. You guys are phenomenal, and your efforts here are so deeply appreciated!

kc-howe added the ? - Needs Triage and bug labels Feb 23, 2024
dantegd (Member) commented Mar 20, 2024

Thanks for the issue @kc-howe! We have identified a few ill-behaviors and issues with spectral clustering from RAFT that affect UMAP in particular. We will be working on solving them. We don't have an ETA yet, but it is on our roadmap as we work on RAFT, cuML and cuVS over the next few releases.
