TypeError: MistralForCausalLM.forward() got an unexpected keyword argument 'causal_mask' #437

Open
CharlieFang1 opened this issue May 8, 2024 · 1 comment

@CharlieFang1

When I try to fine-tune WizardLM-2 7B, I get the error: TypeError: MistralForCausalLM.forward() got an unexpected keyword argument 'causal_mask'.

Full output and stack trace:
model loading
==((====))== Unsloth: Fast Llama patching release 2024.4
\ /| GPU: NVIDIA RTX A6000. Max memory: 47.536 GB. Platform = Linux.
O^O/ _/ \ Pytorch: 2.1.2+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\ / Bfloat16 = TRUE. Xformers = 0.0.22.post7. FA = True.
"--" Free Apache license: http://github.com/unslothai/unsloth
Loading checkpoint shards: 100%|██████████████████| 3/3 [00:05<00:00, 1.91s/it]
model load successfully
Unsloth 2024.4 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
max_steps is given, it will override any value given in num_train_epochs
==((====))== Unsloth - 2x faster free finetuning | Num GPUs = 1
\ /| Num examples = 87 | Num Epochs = 5
O^O/ _/ \ Batch size per device = 2 | Gradient Accumulation steps = 4
\ / Total batch size = 8 | Total steps = 50
"-____-" Number of trainable parameters = 41,943,040
0%| | 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/aiserver/project_sibyl/CogVLM/unsloth_train_no_instruction_wizardLM_2_7b.py", line 142, in <module>
    train()
  File "/home/aiserver/project_sibyl/CogVLM/unsloth_train_no_instruction_wizardLM_2_7b.py", line 88, in train
    trainer_stats = trainer.train()
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 361, in train
    output = super().train(*args, **kwargs)
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
  File "<string>", line 361, in _fast_inner_training_loop
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/transformers/trainer.py", line 3138, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/transformers/trainer.py", line 3161, in compute_loss
    outputs = model(**inputs)
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/accelerate/utils/operations.py", line 825, in forward
    return model_forward(*args, **kwargs)
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/accelerate/utils/operations.py", line 813, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/unsloth/models/llama.py", line 882, in PeftModelForCausalLM_fast_forward
    return self.base_model(
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
    return self.model.forward(*args, **kwargs)
TypeError: MistralForCausalLM.forward() got an unexpected keyword argument 'causal_mask'
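One plausible reading of this traceback is that Unsloth's PeftModelForCausalLM_fast_forward passes a causal_mask keyword down through PEFT's tuner forward. That tuner forwards **kwargs unchanged, and the innermost MistralForCausalLM.forward does not declare that parameter. The banner reads "Fast Llama patching" while the model is MistralForCausalLM, which hints at a Llama-path vs. Mistral-architecture mismatch, but that is an inference, not something confirmed in this thread. The sketch below reproduces the failure mode using hypothetical stand-in classes (MistralLikeModel, TunerWrapper and fast_forward are illustrative only, not the real library code). It also includes a small inspect-based helper for checking whether a forward() accepts a given keyword.

```python
import inspect


# Hypothetical stand-ins for illustration only; these are NOT the real
# transformers / PEFT / Unsloth classes.
class MistralLikeModel:
    """Mimics a forward() whose signature has no `causal_mask` parameter."""

    def forward(self, input_ids, attention_mask=None, labels=None):
        return {"input_ids": input_ids, "labels": labels}


class TunerWrapper:
    """Mimics a tuner wrapper that passes every kwarg through unchanged."""

    def __init__(self, model):
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)


def fast_forward(wrapper, input_ids):
    """Mimics a patched entry point that injects an extra `causal_mask` kwarg."""
    return wrapper.forward(input_ids=input_ids, causal_mask=None)


def accepts_kwarg(fn, name):
    """Return True if `fn` can be called with the keyword argument `name`."""
    for param in inspect.signature(fn).parameters.values():
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            return True  # a **kwargs parameter accepts any keyword
        if param.name == name and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


model = MistralLikeModel()
print(accepts_kwarg(model.forward, "causal_mask"))  # False

error_message = None
try:
    fast_forward(TunerWrapper(model), [1, 2, 3])
except TypeError as err:
    error_message = str(err)
# Ends with: got an unexpected keyword argument 'causal_mask'
print(error_message)
```

Because the wrapper accepts anything via **kwargs, the error only surfaces at the innermost forward(). That matches the last frame above (peft/tuners/tuners_utils.py, line 161).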

My Python 3.10 dependencies:
accelerate 0.29.3
aiofiles 23.2.1
aiohttp 3.9.1
aioprometheus 23.12.0
aiosignal 1.3.1
altair 5.2.0
annotated-types 0.6.0
anyio 4.2.0
anykeystore 0.2
apex 0.1
asgiref 3.7.2
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.1.0
backoff 2.2.1
bcrypt 4.1.2
beautifulsoup4 4.12.3
bitsandbytes 0.41.3.post2
blessed 1.20.0
blinker 1.7.0
blis 0.7.11
boto3 1.34.10
botocore 1.34.10
braceexpand 0.1.7
build 1.1.1
cachetools 5.3.2
catalogue 2.0.10
certifi 2023.11.17
chardet 5.2.0
charset-normalizer 3.3.2
chroma-hnswlib 0.7.3
chromadb 0.4.24
click 8.1.7
cloudpathlib 0.16.0
colorama 0.4.6
coloredlogs 15.0.1
comm 0.2.1
confection 0.1.4
contourpy 1.2.0
cpm-kernels 1.0.11
cryptacular 1.6.2
cycler 0.11.0
cymem 2.0.8
dataclasses-json 0.6.4
datasets 2.18.0
debugpy 1.6.7
decorator 5.1.1
decord 0.6.0
deepspeed 0.12.6
defusedxml 0.7.1
Deprecated 1.2.14
diffusers 0.15.1
dill 0.3.7
distro 1.9.0
docstring_parser 0.16
einops 0.7.0
emoji 2.10.1
en-core-web-sm 3.7.1
et-xmlfile 1.1.0
exceptiongroup 1.2.0
executing 2.0.1
fastapi 0.108.0
ffmpy 0.3.1
filelock 3.13.1
filetype 1.2.0
flash-attn 2.5.8
Flask 3.0.2
flatbuffers 23.5.26
fonttools 4.25.0
frozenlist 1.4.1
fsspec 2023.10.0
gitdb 4.0.11
GitPython 3.1.40
google-auth 2.28.1
googleapis-common-protos 1.62.0
gpustat 1.1.1
gradio 4.12.0
gradio_client 0.8.0
greenlet 3.0.3
grpcio 1.62.0
h11 0.14.0
hjson 3.1.0
httpcore 1.0.2
httptools 0.6.1
httpx 0.26.0
huggingface-hub 0.20.1
humanfriendly 10.0
hupper 1.12
idna 3.6
importlib-metadata 6.11.0
importlib-resources 6.1.1
ipykernel 6.29.3
ipython 8.22.1
itsdangerous 2.1.2
jedi 0.19.1
Jinja2 3.1.2
jmespath 1.0.1
joblib 1.3.2
jsonlines 4.0.0
jsonpatch 1.33
jsonpath-python 1.0.6
jsonpointer 2.4
jsonschema 4.20.0
jsonschema-specifications 2023.12.1
jupyter_client 8.6.0
jupyter_core 5.7.1
kiwisolver 1.4.4
kubernetes 29.0.0
langchain 0.1.10
langchain-community 0.0.25
langchain-core 0.1.28
langchain-openai 0.0.8
langchain-text-splitters 0.0.1
langcodes 3.3.0
langdetect 1.0.9
langsmith 0.1.14
loguru 0.7.2
lxml 5.1.0
markdown-it-py 3.0.0
MarkupSafe 2.1.3
marshmallow 3.21.0
matplotlib 3.8.0
matplotlib-inline 0.1.6
mdurl 0.1.2
mkl-fft 1.3.8
mkl-random 1.2.4
mkl-service 2.4.0
mmh3 4.1.0
monotonic 1.6
mpmath 1.3.0
msgpack 1.0.7
multidict 6.0.4
multiprocess 0.70.15
munkres 1.1.4
murmurhash 1.0.10
mypy-extensions 1.0.0
nest_asyncio 1.6.0
networkx 3.2.1
ninja 1.11.1.1
nltk 3.8.1
numpy 1.26.2
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-ml-py 12.535.133
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.3.101
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
onnxruntime 1.17.1
openai 1.13.3
opencv-python 4.8.1.78
openpyxl 3.1.2
opentelemetry-api 1.23.0
opentelemetry-exporter-otlp-proto-common 1.23.0
opentelemetry-exporter-otlp-proto-grpc 1.23.0
opentelemetry-instrumentation 0.44b0
opentelemetry-instrumentation-asgi 0.44b0
opentelemetry-instrumentation-fastapi 0.44b0
opentelemetry-proto 1.23.0
opentelemetry-sdk 1.23.0
opentelemetry-semantic-conventions 0.44b0
opentelemetry-util-http 0.44b0
orjson 3.9.15
overrides 7.7.0
packaging 23.2
pandas 2.1.4
parso 0.8.3
PasteDeploy 3.1.0
pbkdf2 1.3
pdfminer 20191125
peft 0.10.0
pexpect 4.9.0
pickleshare 0.7.5
Pillow 10.1.0
pip 23.3.1
plaster 1.1.2
plaster-pastedeploy 1.0.1
platformdirs 4.2.0
ply 3.11
posthog 3.4.2
preshed 3.0.9
prompt-toolkit 3.0.42
protobuf 3.20.3
psutil 5.9.0
ptyprocess 0.7.0
pulsar-client 3.4.0
pure-eval 0.2.2
py-cpuinfo 9.0.0
pyarrow 14.0.2
pyarrow-hotfix 0.6
pyasn1 0.5.1
pyasn1-modules 0.3.0
pycryptodome 3.20.0
pydantic 2.5.3
pydantic_core 2.14.6
pydeck 0.8.1b0
pydub 0.25.1
Pygments 2.17.2
pynvml 11.5.0
pyparsing 3.0.9
pypdf 4.1.0
PyPika 0.48.9
pyproject_hooks 1.0.0
PyQt5 5.15.10
PyQt5-sip 12.13.0
pyramid 2.0.2
pyramid-mailer 0.15.1
pytesseract 0.3.10
python-dateutil 2.8.2
python-docx 1.1.0
python-dotenv 1.0.0
python-iso639 2024.2.7
python-magic 0.4.27
python-multipart 0.0.6
python3-openid 3.2.0
pytz 2023.3.post1
PyYAML 6.0.1
pyzmq 25.1.2
quantile-python 1.1
rapidfuzz 3.6.1
ray 2.9.0
referencing 0.32.0
regex 2023.12.25
repoze.sendmail 4.4.1
requests 2.31.0
requests-oauthlib 1.3.1
rich 13.7.0
rpds-py 0.16.2
rsa 4.9
s3transfer 0.10.0
safetensors 0.4.1
scipy 1.11.4
seaborn 0.13.0
semantic-version 2.10.0
sentencepiece 0.2.0
setuptools 68.2.2
shellingham 1.5.4
shtab 1.7.1
sip 6.7.12
six 1.16.0
smart-open 6.4.0
smmap 5.0.1
sniffio 1.3.0
soupsieve 2.5
spacy 3.7.2
spacy-legacy 3.0.12
spacy-loggers 1.0.5
SQLAlchemy 2.0.24
srsly 2.4.8
sse-starlette 1.8.2
stack-data 0.6.2
starlette 0.32.0.post1
streamlit 1.31.1
streamlit-chat 0.1.1
SwissArmyTransformer 0.4.9
sympy 1.12
tabulate 0.9.0
tenacity 8.2.3
tensorboardX 2.6.2.2
thinc 8.2.2
tiktoken 0.6.0
timm 0.9.12
tokenizers 0.19.1
toml 0.10.2
tomli 2.0.1
tomlkit 0.12.0
toolz 0.12.0
torch 2.1.2
torchvision 0.16.2
tornado 6.3.3
tqdm 4.66.1
traitlets 5.14.1
transaction 4.0
transformers 4.40.1
translationstring 1.4
triton 2.1.0
trl 0.8.6
typer 0.9.0
typing_extensions 4.10.0
typing-inspect 0.9.0
tyro 0.8.3
tzdata 2023.3
tzlocal 5.2
unsloth 2024.4
unstructured 0.12.5
unstructured-client 0.21.0
urllib3 2.0.7
uvicorn 0.24.0.post1
uvloop 0.19.0
validators 0.22.0
velruse 1.1.1
venusian 3.1.0
vllm 0.2.6
wasabi 1.1.2
watchdog 3.0.0
watchfiles 0.21.0
wcwidth 0.2.12
weasel 0.3.4
webdataset 0.2.86
WebOb 1.8.7
websocket-client 1.7.0
websockets 11.0.3
Werkzeug 3.0.1
wheel 0.43.0
wrapt 1.16.0
WTForms 3.1.1
wtforms-recaptcha 0.3.2
xformers 0.0.22.post7
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0
zope.deprecation 5.0
zope.interface 6.1
zope.sqlalchemy 3.1
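Only a handful of these packages show up in the traceback above (unsloth, transformers, trl, peft, accelerate, and torch underneath them). Below is a minimal sketch, assuming Python 3.8+ for importlib.metadata, that prints just those versions. The package selection is an assumption based on the traceback, not something requested in this thread.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed selection: packages whose files appear in the traceback, plus torch.
RELEVANT_PACKAGES = ["unsloth", "transformers", "trl", "peft", "accelerate", "torch"]


def relevant_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


for name, ver in relevant_versions(RELEVANT_PACKAGES).items():
    print(f"{name} {ver if ver is not None else '(not installed)'}")
```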

Thanks!

@danielhanchen
Contributor

Hmm, that's weird. Do you know if the Mistral Colab notebook works?
