TypeError: MistralForCausalLM.forward() got an unexpected keyword argument 'causal_mask' #437

CharlieFang1 opened this issue May 8, 2024 · 1 comment


When I try to fine-tune wizard-2 7b, I get the error: TypeError: MistralForCausalLM.forward() got an unexpected keyword argument 'causal_mask'.

Full stack trace as follows:
model loading
==((====))== Unsloth: Fast Llama patching release 2024.4
\ /| GPU: NVIDIA RTX A6000. Max memory: 47.536 GB. Platform = Linux.
O^O/ _/ \ Pytorch: 2.1.2+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\ / Bfloat16 = TRUE. Xformers = 0.0.22.post7. FA = True.
"--" Free Apache license:
Loading checkpoint shards: 100%|██████████████████| 3/3 [00:05<00:00, 1.91s/it]
model load successfully
Unsloth 2024.4 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
max_steps is given, it will override any value given in num_train_epochs
==((====))== Unsloth - 2x faster free finetuning | Num GPUs = 1
\ /| Num examples = 87 | Num Epochs = 5
O^O/ _/ \ Batch size per device = 2 | Gradient Accumulation steps = 4
\ / Total batch size = 8 | Total steps = 50
-" Number of trainable parameters = 41,943,040
0%| | 0/50 [00:00<?, ?it/s]Traceback (most recent call last):
File "/home/aiserver/project_sibyl/CogVLM/", line 142, in
File "/home/aiserver/project_sibyl/CogVLM/", line 88, in train
trainer_stats = trainer.train()
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/trl/trainer/", line 361, in train
output = super().train(*args, **kwargs)
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/transformers/", line 1859, in train
return inner_training_loop(
File "", line 361, in _fast_inner_training_loop
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/transformers/", line 3138, in training_step
loss = self.compute_loss(model, inputs)
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/transformers/", line 3161, in compute_loss
outputs = model(**inputs)
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/torch/nn/modules/", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/torch/nn/modules/", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/accelerate/utils/", line 825, in forward
return model_forward(*args, **kwargs)
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/accelerate/utils/", line 813, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/torch/amp/", line 16, in decorate_autocast
return func(*args, **kwargs)
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/unsloth/models/", line 882, in PeftModelForCausalLM_fast_forward
return self.base_model(
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/torch/nn/modules/", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/torch/nn/modules/", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/aiserver/.conda/envs/emu/lib/python3.10/site-packages/peft/tuners/", line 161, in forward
return self.model.forward(*args, **kwargs)
TypeError: MistralForCausalLM.forward() got an unexpected keyword argument 'causal_mask'
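The last two frames show where the keyword ends up. Unsloth's `PeftModelForCausalLM_fast_forward` calls `self.base_model(...)` and, judging by the error, that call includes a `causal_mask` keyword; PEFT's tuner then passes `*args, **kwargs` straight into `MistralForCausalLM.forward()`, which has no such parameter. The pure-Python sketch below reproduces that pattern with hypothetical stand-in classes (`PlainCausalLM`, `TunerWrapper`) and adds a small `inspect.signature` helper for checking whether a forward accepts a given keyword. It illustrates the failure mode only; it is not Unsloth's or PEFT's actual code.

```python
import inspect


class PlainCausalLM:
    """Hypothetical stand-in for an unpatched forward() that has no
    `causal_mask` parameter and no **kwargs catch-all."""

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        return {"input_ids": input_ids, "labels": labels}


class TunerWrapper:
    """Hypothetical stand-in for a PEFT-style tuner that passes every
    argument through unchanged, like the
    `return self.model.forward(*args, **kwargs)` frame in the traceback."""

    def __init__(self, model):
        self.model = model

    def forward(self, *args, **kwargs):
        return self.model.forward(*args, **kwargs)


def accepts_kwarg(fn, name):
    """Return True if `fn` can take keyword argument `name`, either as a
    named parameter or through a **kwargs catch-all."""
    P = inspect.Parameter
    for param in inspect.signature(fn).parameters.values():
        if param.kind == P.VAR_KEYWORD:
            return True
        if param.name == name and param.kind in (P.POSITIONAL_OR_KEYWORD, P.KEYWORD_ONLY):
            return True
    return False


def try_call(fn, **kwargs):
    """Call fn(**kwargs); return the TypeError message if one is raised, else None."""
    try:
        fn(**kwargs)
    except TypeError as err:
        return str(err)
    return None


if __name__ == "__main__":
    wrapper = TunerWrapper(PlainCausalLM())
    # The wrapper accepts anything, so the problem only surfaces one level down.
    print(accepts_kwarg(wrapper.forward, "causal_mask"))        # True
    print(accepts_kwarg(PlainCausalLM.forward, "causal_mask"))  # False
    # An upstream caller adds a keyword the inner forward() does not know about.
    print(try_call(wrapper.forward, input_ids=[1, 2, 3], causal_mask=None))
```

Against the real objects, the analogous check would be something like `accepts_kwarg(model.base_model.model.forward, "causal_mask")`, assuming `model` is the PEFT-wrapped model from the training script; a `False` there would be consistent with this traceback. Separately, the loading banner reads "Fast Llama patching" while the failing class is `MistralForCausalLM`; this thread does not establish whether that is related.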

My Python 3.10 dependencies:
accelerate 0.29.3
aiofiles 23.2.1
aiohttp 3.9.1
aioprometheus 23.12.0
aiosignal 1.3.1
altair 5.2.0
annotated-types 0.6.0
anyio 4.2.0
anykeystore 0.2
apex 0.1
asgiref 3.7.2
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.1.0
backoff 2.2.1
bcrypt 4.1.2
beautifulsoup4 4.12.3
bitsandbytes 0.41.3.post2
blessed 1.20.0
blinker 1.7.0
blis 0.7.11
boto3 1.34.10
botocore 1.34.10
braceexpand 0.1.7
build 1.1.1
cachetools 5.3.2
catalogue 2.0.10
certifi 2023.11.17
chardet 5.2.0
charset-normalizer 3.3.2
chroma-hnswlib 0.7.3
chromadb 0.4.24
click 8.1.7
cloudpathlib 0.16.0
colorama 0.4.6
coloredlogs 15.0.1
comm 0.2.1
confection 0.1.4
contourpy 1.2.0
cpm-kernels 1.0.11
cryptacular 1.6.2
cycler 0.11.0
cymem 2.0.8
dataclasses-json 0.6.4
datasets 2.18.0
debugpy 1.6.7
decorator 5.1.1
decord 0.6.0
deepspeed 0.12.6
defusedxml 0.7.1
Deprecated 1.2.14
diffusers 0.15.1
dill 0.3.7
distro 1.9.0
docstring_parser 0.16
einops 0.7.0
emoji 2.10.1
en-core-web-sm 3.7.1
et-xmlfile 1.1.0
exceptiongroup 1.2.0
executing 2.0.1
fastapi 0.108.0
ffmpy 0.3.1
filelock 3.13.1
filetype 1.2.0
flash-attn 2.5.8
Flask 3.0.2
flatbuffers 23.5.26
fonttools 4.25.0
frozenlist 1.4.1
fsspec 2023.10.0
gitdb 4.0.11
GitPython 3.1.40
google-auth 2.28.1
googleapis-common-protos 1.62.0
gpustat 1.1.1
gradio 4.12.0
gradio_client 0.8.0
greenlet 3.0.3
grpcio 1.62.0
h11 0.14.0
hjson 3.1.0
httpcore 1.0.2
httptools 0.6.1
httpx 0.26.0
huggingface-hub 0.20.1
humanfriendly 10.0
hupper 1.12
idna 3.6
importlib-metadata 6.11.0
importlib-resources 6.1.1
ipykernel 6.29.3
ipython 8.22.1
itsdangerous 2.1.2
jedi 0.19.1
Jinja2 3.1.2
jmespath 1.0.1
joblib 1.3.2
jsonlines 4.0.0
jsonpatch 1.33
jsonpath-python 1.0.6
jsonpointer 2.4
jsonschema 4.20.0
jsonschema-specifications 2023.12.1
jupyter_client 8.6.0
jupyter_core 5.7.1
kiwisolver 1.4.4
kubernetes 29.0.0
langchain 0.1.10
langchain-community 0.0.25
langchain-core 0.1.28
langchain-openai 0.0.8
langchain-text-splitters 0.0.1
langcodes 3.3.0
langdetect 1.0.9
langsmith 0.1.14
loguru 0.7.2
lxml 5.1.0
markdown-it-py 3.0.0
MarkupSafe 2.1.3
marshmallow 3.21.0
matplotlib 3.8.0
matplotlib-inline 0.1.6
mdurl 0.1.2
mkl-fft 1.3.8
mkl-random 1.2.4
mkl-service 2.4.0
mmh3 4.1.0
monotonic 1.6
mpmath 1.3.0
msgpack 1.0.7
multidict 6.0.4
multiprocess 0.70.15
munkres 1.1.4
murmurhash 1.0.10
mypy-extensions 1.0.0
nest_asyncio 1.6.0
networkx 3.2.1
nltk 3.8.1
numpy 1.26.2
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-ml-py 12.535.133
nvidia-nccl-cu12 2.18.1
nvidia-nvjitlink-cu12 12.3.101
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
onnxruntime 1.17.1
openai 1.13.3
openpyxl 3.1.2
opentelemetry-api 1.23.0
opentelemetry-exporter-otlp-proto-common 1.23.0
opentelemetry-exporter-otlp-proto-grpc 1.23.0
opentelemetry-instrumentation 0.44b0
opentelemetry-instrumentation-asgi 0.44b0
opentelemetry-instrumentation-fastapi 0.44b0
opentelemetry-proto 1.23.0
opentelemetry-sdk 1.23.0
opentelemetry-semantic-conventions 0.44b0
opentelemetry-util-http 0.44b0
orjson 3.9.15
overrides 7.7.0
packaging 23.2
pandas 2.1.4
parso 0.8.3
PasteDeploy 3.1.0
pbkdf2 1.3
pdfminer 20191125
peft 0.10.0
pexpect 4.9.0
pickleshare 0.7.5
Pillow 10.1.0
pip 23.3.1
plaster 1.1.2
plaster-pastedeploy 1.0.1
platformdirs 4.2.0
ply 3.11
posthog 3.4.2
preshed 3.0.9
prompt-toolkit 3.0.42
protobuf 3.20.3
psutil 5.9.0
ptyprocess 0.7.0
pulsar-client 3.4.0
pure-eval 0.2.2
py-cpuinfo 9.0.0
pyarrow 14.0.2
pyarrow-hotfix 0.6
pyasn1 0.5.1
pyasn1-modules 0.3.0
pycryptodome 3.20.0
pydantic 2.5.3
pydantic_core 2.14.6
pydeck 0.8.1b0
pydub 0.25.1
Pygments 2.17.2
pynvml 11.5.0
pyparsing 3.0.9
pypdf 4.1.0
PyPika 0.48.9
pyproject_hooks 1.0.0
PyQt5 5.15.10
PyQt5-sip 12.13.0
pyramid 2.0.2
pyramid-mailer 0.15.1
pytesseract 0.3.10
python-dateutil 2.8.2
python-docx 1.1.0
python-dotenv 1.0.0
python-iso639 2024.2.7
python-magic 0.4.27
python-multipart 0.0.6
python3-openid 3.2.0
pytz 2023.3.post1
PyYAML 6.0.1
pyzmq 25.1.2
quantile-python 1.1
rapidfuzz 3.6.1
ray 2.9.0
referencing 0.32.0
regex 2023.12.25
repoze.sendmail 4.4.1
requests 2.31.0
requests-oauthlib 1.3.1
rich 13.7.0
rpds-py 0.16.2
rsa 4.9
s3transfer 0.10.0
safetensors 0.4.1
scipy 1.11.4
seaborn 0.13.0
semantic-version 2.10.0
sentencepiece 0.2.0
setuptools 68.2.2
shellingham 1.5.4
shtab 1.7.1
sip 6.7.12
six 1.16.0
smart-open 6.4.0
smmap 5.0.1
sniffio 1.3.0
soupsieve 2.5
spacy 3.7.2
spacy-legacy 3.0.12
spacy-loggers 1.0.5
SQLAlchemy 2.0.24
srsly 2.4.8
sse-starlette 1.8.2
stack-data 0.6.2
starlette 0.32.0.post1
streamlit 1.31.1
streamlit-chat 0.1.1
SwissArmyTransformer 0.4.9
sympy 1.12
tabulate 0.9.0
tenacity 8.2.3
thinc 8.2.2
tiktoken 0.6.0
timm 0.9.12
tokenizers 0.19.1
toml 0.10.2
tomli 2.0.1
tomlkit 0.12.0
toolz 0.12.0
torch 2.1.2
torchvision 0.16.2
tornado 6.3.3
tqdm 4.66.1
traitlets 5.14.1
transaction 4.0
transformers 4.40.1
translationstring 1.4
triton 2.1.0
trl 0.8.6
typer 0.9.0
typing_extensions 4.10.0
typing-inspect 0.9.0
tyro 0.8.3
tzdata 2023.3
tzlocal 5.2
unsloth 2024.4
unstructured 0.12.5
unstructured-client 0.21.0
urllib3 2.0.7
uvicorn 0.24.0.post1
uvloop 0.19.0
validators 0.22.0
velruse 1.1.1
venusian 3.1.0
vllm 0.2.6
wasabi 1.1.2
watchdog 3.0.0
watchfiles 0.21.0
wcwidth 0.2.12
weasel 0.3.4
webdataset 0.2.86
WebOb 1.8.7
websocket-client 1.7.0
websockets 11.0.3
Werkzeug 3.0.1
wheel 0.43.0
wrapt 1.16.0
WTForms 3.1.1
wtforms-recaptcha 0.3.2
xformers 0.0.22.post7
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0
zope.deprecation 5.0
zope.interface 6.1
zope.sqlalchemy 3.1
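Only a handful of the packages listed above sit on the failing call path. A minimal sketch using the standard library's `importlib.metadata` that reports just those versions, which can make a report like this easier to scan; the selection in `RELEVANT` is an assumption based on the traceback frames and the Unsloth banner, not something the thread prescribes.

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions that appear in the traceback frames or the Unsloth banner above;
# this selection is a judgment call, not an official list.
RELEVANT = ["unsloth", "transformers", "trl", "peft", "accelerate", "torch", "xformers"]


def installed_versions(names):
    """Map each distribution name to its installed version string, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(RELEVANT).items():
        print(f"{name:<14} {ver or 'not installed'}")
```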


Hmm, that's weird. Do you know if Colab Mistral works?

