
NameError: name 'FlashMHA' is not defined #182

lamasJose opened this issue Apr 16, 2024 · 3 comments

I completed the installation in a conda environment without any problems, but when I try to run the multiomics integration example I get the error "name 'FlashMHA' is not defined". The following warnings are also printed at the beginning:

/home/jmlamas/miniforge3/envs/scgpt2/lib/python3.10/site-packages/scgpt/model/ UserWarning: flash_attn is not installed
  warnings.warn("flash_attn is not installed")
/home/jmlamas/miniforge3/envs/scgpt2/lib/python3.10/site-packages/scgpt/model/ UserWarning: flash_attn is not installed
  warnings.warn("flash_attn is not installed")

However, pip list gives:
(scgpt2) [jmlamas@cluster1-head1 ~]$ pip list
Package Version

absl-py 2.1.0
aiohttp 3.9.4
aiosignal 1.3.1
anndata 0.10.7
appdirs 1.4.4
array_api_compat 1.6
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
cached-property 1.5.2
cell-gears 0.0.2
certifi 2024.2.2
charset-normalizer 3.3.2
chex 0.1.86
click 8.1.7
contextlib2 21.6.0
contourpy 1.2.1
cycler 0.12.1
datasets 2.18.0
dcor 0.6
decorator 5.1.1
Deprecated 1.2.14
dill 0.3.8
docker-pycreds 0.4.0
docrep 0.3.2
einops 0.7.0
et-xmlfile 1.1.0
etils 1.7.0
exceptiongroup 1.2.0
executing 2.0.1
filelock 3.13.4
flash_attn 1.0.4
flax 0.8.2
fonttools 4.51.0
frozenlist 1.4.1
fsspec 2024.2.0
gitdb 4.0.11
GitPython 3.1.43
h5py 3.11.0
huggingface-hub 0.22.2
idna 3.7
igraph 0.11.4
importlib_resources 6.4.0
ipython 8.23.0
jax 0.4.26
jaxlib 0.4.26
jedi 0.19.1
Jinja2 3.1.3
joblib 1.4.0
kiwisolver 1.4.5
legacy-api-wrap 1.4
leidenalg 0.10.2
lightning-utilities 0.11.2
llvmlite 0.42.0
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.8.4
matplotlib-inline 0.1.7
mdurl 0.1.2
ml_collections 0.1.1
ml-dtypes 0.4.0
mpmath 1.3.0
msgpack 1.0.8
mudata 0.2.3
multidict 6.0.5
multipledispatch 1.0.0
multiprocess 0.70.16
natsort 8.4.0
nest-asyncio 1.6.0
networkx 3.3
numba 0.59.1
numpy 1.26.4
numpyro 0.14.0
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.1.105
openpyxl 3.1.2
opt-einsum 3.3.0
optax 0.2.2
orbax 0.1.7
orbax-checkpoint 0.5.9
packaging 24.0
pandas 2.2.2
parso 0.8.4
patsy 0.5.6
pexpect 4.9.0
pillow 10.3.0
pip 24.0
prompt-toolkit 3.0.43
protobuf 4.25.3
psutil 5.9.8
ptyprocess 0.7.0
pure-eval 0.2.2
pyarrow 15.0.2
pyarrow-hotfix 0.6
pydot 2.0.0
Pygments 2.17.2
pynndescent 0.5.12
pyparsing 3.1.2
pyro-api 0.1.2
pyro-ppl 1.9.0
python-dateutil 2.9.0.post0
pytorch-lightning 1.9.5
pytz 2024.1
PyYAML 6.0.1
requests 2.31.0
rich 13.7.1
scanpy 1.10.1
scgpt 0.2.1
scib 1.1.5
scikit-learn 1.4.2
scikit-misc 0.3.1
scipy 1.13.0
scvi-tools 0.20.3
seaborn 0.13.2
sentry-sdk 1.45.0
session_info 1.0.0
setproctitle 1.3.3
setuptools 69.5.1
six 1.16.0
smmap 5.0.1
stack-data 0.6.3
statsmodels 0.14.1
stdlib-list 0.10.0
sympy 1.12
tensorstore 0.1.56
texttable 1.7.0
threadpoolctl 3.4.0
toolz 0.12.1
torch 2.1.2
torchdata 0.7.1
torchmetrics 1.3.2
torchtext 0.16.2
tqdm 4.66.2
traitlets 5.14.2
triton 2.1.0
typing_extensions 4.11.0
tzdata 2024.1
umap-learn 0.5.6
urllib3 2.2.1
wandb 0.16.6
wcwidth 0.2.13
wheel 0.43.0
wrapt 1.16.0
xxhash 3.4.1
yarl 1.9.4
zipp 3.18.1

So flash_attn seems to be correctly installed. Does anyone know what the problem is?

I'm facing the same issue!

Hi, thank you for the question and for sharing your environment info. It looks like you have flash-attn 1.0.4 installed, so the warning basically says that the FlashMHA class could not be imported. See here:

try:
    from flash_attn.flash_attention import FlashMHA
    flash_attn_available = True
except ImportError:
    import warnings
    warnings.warn("flash_attn is not installed")
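Because this guard turns every ImportError into the same generic "not installed" message, the warning cannot distinguish a missing package from one that is present but fails to load. A minimal sketch of a variant that keeps the original exception is below; optional_import and flash_attn_error are hypothetical names for illustration, not part of scGPT.

```python
import importlib
import warnings
from typing import Any, Optional, Tuple


def optional_import(module: str, attr: str) -> Tuple[Optional[Any], Optional[ImportError]]:
    """Try `from module import attr`; return (obj, None) on success or (None, error)."""
    try:
        mod = importlib.import_module(module)
    except ImportError as exc:  # includes ModuleNotFoundError and "undefined symbol" load errors
        return None, exc
    try:
        return getattr(mod, attr), None
    except AttributeError:
        # Mirror the error a plain `from module import attr` would raise.
        return None, ImportError(f"cannot import name {attr!r} from {module!r}")


# Same check as the guard above, but the real reason for a failure is kept.
FlashMHA, flash_attn_error = optional_import("flash_attn.flash_attention", "FlashMHA")
flash_attn_available = flash_attn_error is None
if not flash_attn_available:
    warnings.warn(f"flash_attn could not be imported: {flash_attn_error!r}")
```

With this shape, a broken binary extension shows up in the warning text itself instead of being reported as "not installed".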

And in flash-attn 1.0.4 you should have the class, see here:

Therefore, I think this is still most likely an installation issue. Could you try running the following in your environment?

python -c "import flash_attn"
python -c "from flash_attn.flash_attention import FlashMHA"

These should tell you whether the package has been properly installed.
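The same two checks can be run from a single script that also reports what the package metadata says. This helps explain why pip list can show flash_attn while the import still fails: pip list reads the installed distribution's metadata and never imports the package. A minimal sketch; installed_version and import_error are hypothetical helper names, and the distribution names flash_attn and torch are taken from the pip list above.

```python
import importlib.metadata
import traceback
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Version recorded in the distribution metadata (what pip list shows), or None."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def import_error(statement: str) -> Optional[str]:
    """Run an import statement in a fresh namespace; return the traceback text on failure."""
    try:
        exec(statement, {})
    except Exception:  # broad on purpose: report any failure during import
        return traceback.format_exc()
    return None


if __name__ == "__main__":
    print("flash_attn metadata version:", installed_version("flash_attn"))
    print("torch metadata version:", installed_version("torch"))
    for stmt in ("import flash_attn",
                 "from flash_attn.flash_attention import FlashMHA"):
        err = import_error(stmt)
        print(f"{stmt!r}:", "OK" if err is None else "FAILED\n" + err)
```

If the metadata version is printed but the import fails, the traceback shows the actual cause that the scGPT warning hides.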

In my case, I'm running on Colab, and there appears to be an incompatibility between CUDA and flash_attn:

!python -c "import flash_attn"
!python -c "from flash_attn.flash_attention import FlashMHA"

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/", line 7, in <module>
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/", line 5, in <module>
    import flash_attn_cuda
ImportError: /usr/local/lib/python3.10/dist-packages/ undefined symbol: _ZN3c104cuda20CUDACachingAllocator9allocatorE

Do you have any suggestions on how to resolve it?
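The mangled name in this traceback can be decoded by hand: under the Itanium C++ ABI used on Linux, _ZN ... E wraps a nested name made of length-prefixed identifiers, so _ZN3c104cuda20CUDACachingAllocator9allocatorE is c10::cuda::CUDACachingAllocator::allocator. c10 is PyTorch's core C++ library, so an undefined c10 symbol usually suggests that the compiled flash_attn_cuda extension was built against a different PyTorch version than the one in the runtime, and reinstalling or rebuilding flash-attn against the installed torch is a commonly suggested remedy. That diagnosis is an inference from the symbol name, not something confirmed in this thread. A minimal sketch of the decoding (helper names are hypothetical; it handles only plain nested names, and c++filt is the general tool):

```python
import re
from typing import Optional


def undefined_symbol(message: str) -> Optional[str]:
    """Pull the mangled name out of an 'undefined symbol: ...' ImportError message."""
    match = re.search(r"undefined symbol: (\S+)", message)
    return match.group(1) if match else None


def demangle_nested_name(symbol: str) -> str:
    """Demangle a plain Itanium C++ ABI nested name, e.g. _ZN3foo3barE -> foo::bar.

    Only length-prefixed identifiers are handled (no templates, function
    parameters, or substitutions); use c++filt for anything more complex.
    """
    if not (symbol.startswith("_ZN") and symbol.endswith("E")):
        raise ValueError(f"not a simple nested name: {symbol!r}")
    body = symbol[3:-1]  # strip the _ZN prefix and the closing E
    parts = []
    i = 0
    while i < len(body):
        j = i
        while j < len(body) and body[j].isdigit():
            j += 1
        if j == i:
            raise ValueError(f"unsupported mangling at {body[i:]!r}")
        length = int(body[i:j])
        name = body[j:j + length]
        if len(name) != length:
            raise ValueError(f"truncated identifier in {symbol!r}")
        parts.append(name)
        i = j + length
    return "::".join(parts)


# The message from the Colab traceback above.
msg = ("ImportError: /usr/local/lib/python3.10/dist-packages/ undefined symbol: "
       "_ZN3c104cuda20CUDACachingAllocator9allocatorE")
print(demangle_nested_name(undefined_symbol(msg)))  # c10::cuda::CUDACachingAllocator::allocator
```

Comparing torch.__version__ with the PyTorch version the flash-attn build targeted is a reasonable next check once the symbol is known to come from c10.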
