Clustering does not accept input from CountVectorizer or TfidfVectorizer #5807

Open
erico-imgproj opened this issue Mar 16, 2024 · 3 comments
Labels
? - Needs Triage Need team to review and classify bug Something isn't working

Comments

@erico-imgproj

Describe the bug
NLP clustering does not work properly. The example code works fine for classification tasks, but the clustering algorithms do not accept the output of classes such as CountVectorizer or TfidfVectorizer.

The same error also occurs when running PCA on the output of CountVectorizer or TfidfVectorizer.

Steps/Code to reproduce bug

import cupy as cp

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer

from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from cuml.dask.common import to_sparse_dask_array

# Create a local CUDA cluster
cluster = LocalCUDACluster()
client = Client(cluster)

# Load corpus
twenty_train = fetch_20newsgroups(subset='train',
                          shuffle=True, random_state=42)

cv = CountVectorizer()
xformed = cv.fit_transform(twenty_train.data).astype(cp.float32)
X = to_sparse_dask_array(xformed, client)

from cuml.dask.cluster import DBSCAN
model = DBSCAN(min_samples=5)
yhat = model.fit_predict(X)  # the error is raised here

Everything up to this point runs without errors, but calling the clustering algorithm raises the following error:

Key:       _func-56a8e14b-97b8-4cd4-8ce5-39f501002e85
Function:  _func
args:      (b'\x83\xc0\x8c\xc0\x9e\x9dD\xc0\x8a7\xb2\x03\x8f\xf19\xdc', dask.array<from-value, shape=(11314, 130107), dtype=float64, chunksize=(11314, 130107), chunktype=cupyx.csr_matrix>)
kwargs:    {'min_samples': 5, 'verbose': False}
Exception: "ValueError('setting an array element with a sequence.')"

2024-03-16 08:04:20,064 - distributed.worker - WARNING - Compute Failed
Key:       _func-43403761-3f49-4ad0-b1f9-3321d4ceb2f6
Function:  _func
args:      (b'\x83\xc0\x8c\xc0\x9e\x9dD\xc0\x8a7\xb2\x03\x8f\xf19\xdc', dask.array<from-value, shape=(11314, 130107), dtype=float64, chunksize=(11314, 130107), chunktype=cupyx.csr_matrix>)
kwargs:    {'min_samples': 5, 'verbose': False}
Exception: "ValueError('setting an array element with a sequence.')"

2024-03-16 08:04:20,328 - distributed.worker - WARNING - Compute Failed
Key:       _func-276ba2f7-0fcf-409b-994f-b9e224811d3a
Function:  _func
args:      (b'\x83\xc0\x8c\xc0\x9e\x9dD\xc0\x8a7\xb2\x03\x8f\xf19\xdc', dask.array<from-value, shape=(11314, 130107), dtype=float64, chunksize=(11314, 130107), chunktype=cupyx.csr_matrix>)
kwargs:    {'min_samples': 5, 'verbose': False}
Exception: "ValueError('setting an array element with a sequence.')"

2024-03-16 08:04:20,407 - distributed.worker - WARNING - Compute Failed
Key:       _func-75df8a0c-79b6-4cf0-821c-b6d995dc472a
Function:  _func
args:      (b'\x83\xc0\x8c\xc0\x9e\x9dD\xc0\x8a7\xb2\x03\x8f\xf19\xdc', dask.array<from-value, shape=(11314, 130107), dtype=float64, chunksize=(11314, 130107), chunktype=cupyx.csr_matrix>)
kwargs:    {'min_samples': 5, 'verbose': False}
Exception: "ValueError('setting an array element with a sequence.')"

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/erico/lab/packages_dask/cuml/dask/cluster/dbscan.py", line 160, in fit_predict
    self.fit(X, out_dtype)
  File "/home/erico/lab/packages_dask/cuml/internals/memory_utils.py", line 87, in cupy_rmm_wrapper
    return func(*args, **kwargs)
  File "/home/erico/lab/packages_dask/cuml/dask/cluster/dbscan.py", line 133, in fit
    wait_and_raise_from_futures(dbscan_fit)
  File "/home/erico/lab/packages_dask/cuml/dask/common/utils.py", line 164, in wait_and_raise_from_futures
    raise_exception_from_futures(futures)
  File "/home/erico/lab/packages_dask/cuml/dask/common/utils.py", line 152, in raise_exception_from_futures
    raise RuntimeError(
RuntimeError: 4 of 4 worker jobs failed: setting an array element with a sequence., setting an array element with a sequence., setting an array element with a sequence., setting an array element with a sequence.

For PCA, the following code is added before the clustering step:

from cuml.decomposition import PCA
pca_float = PCA(n_components = 500)
pca_float.fit(X)

and the error generated is

  File "/home/erico/lab/packages_dask/cuml/internals/array.py", line 290, in __init__
    new_data = cur_xpy.frombuffer(data, dtype=dtype)
  File "/home/erico/lab/packages_dask/cupy/_creation/from_data.py", line 167, in frombuffer
    return asarray(numpy.frombuffer(*args, **kwargs))
TypeError: a bytes-like object is required, not 'Array'

During handling of the above exception, another exception occurred:

TypeError: float() argument must be a string or a real number, not 'csr_matrix'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/erico/lab/packages_dask/cuml/internals/api_decorators.py", line 188, in wrapper
    ret = func(*args, **kwargs)
  File "/home/erico/lab/packages_dask/cuml/internals/api_decorators.py", line 393, in dispatch
    return self.dispatch_func(func_name, gpu_func, *args, **kwargs)
  File "/home/erico/lab/packages_dask/cuml/internals/api_decorators.py", line 190, in wrapper
    return func(*args, **kwargs)
  File "base.pyx", line 687, in cuml.internals.base.UniversalBase.dispatch_func
  File "pca.pyx", line 434, in cuml.decomposition.pca.PCA.fit
  File "/home/erico/lab/packages_dask/nvtx/nvtx.py", line 116, in inner
    result = func(*args, **kwargs)
  File "/home/erico/lab/packages_dask/cuml/internals/input_utils.py", line 380, in input_to_cuml_array
    arr = CumlArray.from_input(
  File "/home/erico/lab/packages_dask/cuml/internals/memory_utils.py", line 87, in cupy_rmm_wrapper
    return func(*args, **kwargs)
  File "/home/erico/lab/packages_dask/nvtx/nvtx.py", line 116, in inner
    result = func(*args, **kwargs)
  File "/home/erico/lab/packages_dask/cuml/internals/array.py", line 1114, in from_input
    arr = cls(X, index=index, order=requested_order, validate=False)
  File "/home/erico/lab/packages_dask/cuml/internals/memory_utils.py", line 87, in cupy_rmm_wrapper
    return func(*args, **kwargs)
  File "/home/erico/lab/packages_dask/nvtx/nvtx.py", line 116, in inner
    result = func(*args, **kwargs)
  File "/home/erico/lab/packages_dask/cuml/internals/array.py", line 292, in __init__
    new_data = cur_xpy.asarray(data, dtype=dtype)
  File "/home/erico/lab/packages_dask/cupy/_creation/from_data.py", line 88, in asarray
    return _core.array(a, dtype, False, order, blocking=blocking)
  File "cupy/_core/core.pyx", line 2379, in cupy._core.core.array
  File "cupy/_core/core.pyx", line 2406, in cupy._core.core.array
  File "cupy/_core/core.pyx", line 2541, in cupy._core.core._array_default
ValueError: setting an array element with a sequence.
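For the PCA case, one alternative worth trying is truncated SVD (latent semantic analysis). scikit-learn documents it as the sparse-friendly counterpart of PCA because it does not centre the data, so a term-count matrix never has to be densified. Whether cuML's own `TruncatedSVD` accepts sparse input in this release is not established in this thread, so the sketch below computes it directly with SciPy's `scipy.sparse.linalg.svds` on a CPU CSR matrix. The function name `truncated_svd` and the toy data are hypothetical.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import svds


def truncated_svd(X, k):
    """Reduce sparse X to k dense components via truncated SVD (LSA-style).

    Unlike PCA, X is not mean-centred, so it never has to be densified.
    Returns the (n_samples, k) projection X @ V_k and the k singular values,
    both ordered from largest to smallest singular value.
    """
    _, s, vt = svds(X, k=k)
    order = np.argsort(s)[::-1]  # sort explicitly instead of relying on svds' order
    s, vt = s[order], vt[order]
    Z = np.asarray(X @ vt.T)  # sparse @ dense gives a dense (n_samples, k) array
    return Z, s


# Toy sparse "term-count" matrix: 50 documents x 40 terms, ~20% non-zero.
rng = np.random.default_rng(0)
dense = rng.random((50, 40))
dense[dense < 0.8] = 0.0
X = sp.csr_matrix(dense)

Z, s = truncated_svd(X, k=5)
print(Z.shape)  # (50, 5)
```

The dense result is small: with `k = 500`, as in the PCA call above, the float32 projection for 11314 documents takes 22,628,000 bytes (about 22.6 MB), so it could be handed to estimators that only accept dense input.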

Expected behavior
The clustering algorithm should return the cluster assignments as a list, and PCA should return a tabular data set for post-processing.

Environment details:

  • Environment location: Bare-metal
  • Linux Distro/Architecture: Debian GNU/Linux 10 (buster)
  • GPU Model/Driver: Tesla V100S-PCI, driver version 520.61.05
  • CUDA: 11.8
  • Method of cuDF & cuML install: pip
cubinlinker-cu11          0.3.0.post1
cucim-cu11                24.2.0
cuda-python               11.8.3
cudf-cu11                 24.2.2
cugraph-cu11              24.2.0
cuml-cu11                 24.2.0
cuproj-cu11               24.2.0
cupy-cuda11x              13.0.0
cuspatial-cu11            24.2.0
cuxfilter-cu11            24.2.0
dask                      2024.1.1
dask-cuda                 24.2.0
dask-cudf-cu11            24.2.2
dask-glm                  0.3.2
dask-ml                   2023.3.24
raft-dask-cu11            24.2.0
rapids-dask-dependency    24.2.0

Additional context
This error is related to the one first reported in #5805.

@erico-imgproj erico-imgproj added ? - Needs Triage Need team to review and classify bug Something isn't working labels Mar 16, 2024
@dantegd
Member

dantegd commented Mar 20, 2024

Thanks for the issue @erico-imgproj! I notice that you're using from sklearn.feature_extraction.text import CountVectorizer, have you tried using cuML's CountVectorizer? Have you run into issues running with it?

We'll also work on fixing the issue you're seeing with Scikit-learn's CountVectorizer, since it should work as well.

@erico-imgproj
Author

Hi @dantegd, I tested both, but the error persists. The code up to the feature generation comes from an example on the cuML website, so it should work. The problem seems to be that neither the CountVectorizer nor the TfidfVectorizer output is recognized by the clustering algorithms. Classification tasks on the same features work fine.
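A possible interim workaround, assuming the failures come from these estimators expecting dense input rather than a CSR matrix (the thread does not confirm this), is to densify the features and check the memory cost first. The sketch below uses SciPy on the CPU. On the GPU the analogous calls would be `cupyx.scipy.sparse.issparse` and the matrix's `.toarray()` method, and for the Dask array the conversion would presumably have to be applied per chunk (for example with `map_blocks`). The helper names `dense_nbytes` and `densify` and the 8 GiB default budget are illustrative only.

```python
import numpy as np
import scipy.sparse as sp


def dense_nbytes(shape, dtype=np.float32):
    """Bytes needed to hold a matrix of `shape` as a dense array of `dtype`."""
    n_rows, n_cols = shape
    return n_rows * n_cols * np.dtype(dtype).itemsize


def densify(X, dtype=np.float32, max_bytes=8 * 2**30):
    """Return X as a dense ndarray, refusing if the copy would exceed max_bytes.

    Hypothetical helper; max_bytes (8 GiB here) is an arbitrary budget.
    """
    if not sp.issparse(X):
        return np.asarray(X, dtype=dtype)  # already dense: just fix the dtype
    needed = dense_nbytes(X.shape, dtype)
    if needed > max_bytes:
        raise MemoryError(
            f"dense copy of {X.shape} as {np.dtype(dtype).name} "
            f"needs {needed / 2**30:.1f} GiB"
        )
    return X.astype(dtype).toarray()


# Toy CSR input standing in for CountVectorizer output.
X = sp.csr_matrix(np.array([[0.0, 2.0, 0.0], [1.0, 0.0, 0.0]]))
D = densify(X)
print(D.dtype, D.shape)  # float32 (2, 3)
print(dense_nbytes((11314, 130107)))  # 5888122392
```

At the 11314 × 130107 shape shown in the worker log, a float32 dense copy needs 5,888,122,392 bytes (about 5.5 GiB), which fits on a 32 GB V100S.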

@AndreasKarasenko

AndreasKarasenko commented Apr 12, 2024

This also happens with cuML's RandomForestClassifier. However, models like Naive Bayes and SVC do work in the same setup. Is there a specific reason why RF can't handle a csr_matrix?
Here is a minimal example. I tried both CountVectorizer and HashingVectorizer (from both cuML and scikit-learn).

import time

import cudf
import cupy as cp
import numpy as np

from cuml.feature_extraction.text import CountVectorizer, HashingVectorizer
from cuml.naive_bayes import MultinomialNB as cuNB
from cuml.svm import SVC as cuSVC
from cupyx.scipy.sparse import csr_matrix
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from sklearn.datasets import fetch_20newsgroups

# Create a local CUDA cluster
cluster = LocalCUDACluster()
client = Client(cluster)

# Load corpus

twenty_train = fetch_20newsgroups(subset="train", shuffle=True, random_state=42)
twenty_train = cudf.DataFrame.from_dict(
    {"data": twenty_train.data, "target": twenty_train.target}
)
cv = HashingVectorizer()

xformed = cv.fit_transform(twenty_train.data).astype(np.float32)

X = csr_matrix(xformed).astype(cp.float32)
y = cp.asarray(twenty_train.target).astype(cp.int32)

from cuml.ensemble import RandomForestClassifier as cuRF

# Try NB
model = cuNB()
start = time.time()
model.fit(X, y) # works
end = time.time()
print("Time to train: ", end - start)

# Try RF
model = cuRF()
start = time.time()
model.fit(X, y) # fails
end = time.time()
print("Time to train: ", end - start)

I get the same errors as @erico-imgproj.
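Densifying is much less attractive for the RandomForest example, because `HashingVectorizer` output is very wide. Assuming cuML's `HashingVectorizer` uses the same default `n_features = 2**20` as scikit-learn's, a float32 dense copy of the 11314 training documents needs exactly 44.1953125 GiB, more than the 32 GB of a Tesla V100S. The arithmetic is sketched below, together with a hypothetical helper, `max_pow2_features`, that picks the largest power-of-two `n_features` whose dense matrix fits a given budget.

```python
GIB = 2**30
V100S_BYTES = 32 * 10**9  # Tesla V100S memory: 32 GB (decimal gigabytes)


def dense_bytes(n_rows, n_cols, itemsize=4):
    """Bytes for a dense n_rows x n_cols matrix (itemsize=4 for float32)."""
    return n_rows * n_cols * itemsize


def max_pow2_features(n_rows, budget_bytes, itemsize=4):
    """Largest power-of-two column count whose dense matrix fits in budget_bytes.

    Hypothetical helper for choosing HashingVectorizer's n_features.
    """
    limit = budget_bytes // (n_rows * itemsize)
    if limit < 1:
        raise ValueError("budget too small for even one column")
    return 1 << (limit.bit_length() - 1)


n_docs = 11314  # size of the 20 Newsgroups "train" split used above
hashed = dense_bytes(n_docs, 2**20)  # assumed default n_features
print(hashed / GIB, hashed > V100S_BYTES)  # 44.1953125 True
print(max_pow2_features(n_docs, 8 * GIB))  # 131072
```

Passing a smaller `n_features` to the vectorizer would shrink the dense copy proportionally, at the cost of more hash collisions.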
