RuntimeError: shape '[18, 217, 32, 128]' is invalid for input of size 3999744

#151, opened by kerkathy

Hi, I'm running into a strange shape error during forward() when I run a script on a CUDA 11.3 / 12.0 machine. The same script runs without any problem on a CUDA 12.2 machine.
I'm not sure whether bitsandbytes or the Llama model is causing the issue, or whether some of my packages are at the wrong version.
A similar issue was also raised in #133.
Any suggestion would be appreciated.
In brief:

  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 195, in forward
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[18, 217, 32, 128]' is invalid for input of size 3999744
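The numbers in the error can be checked by hand. The view asks for 18 x 217 x 32 x 128 elements, but k_proj returned exactly a quarter of that, which is what 8 heads of size 128 per token would produce. Meta-Llama-3-8B uses grouped-query attention with 32 query heads and 8 key/value heads (`num_key_value_heads: 8` in its config), so this looks consistent with modeling code that reshapes keys by `num_heads` rather than `num_key_value_heads`. The snippet only checks the arithmetic; it is not a confirmed diagnosis.

```python
# Values copied from the error message above.
bsz, q_len = 18, 217          # batch size and sequence length
num_heads, head_dim = 32, 128 # the view() at line 195 uses self.num_heads
actual_numel = 3999744        # number of elements k_proj actually produced

# Elements the requested view [18, 217, 32, 128] would need.
expected_numel = bsz * q_len * num_heads * head_dim

# How many head_dim-sized heads per token the real output contains.
per_head_slot = bsz * q_len * head_dim
kv_heads, remainder = divmod(actual_numel, per_head_slot)

print(expected_numel)       # 15998976
print(kv_heads, remainder)  # 8 0
```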

I loaded the model like this:

from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = load_lm_tokenizer(model_name)  # my own helper (not shown here)
config = AutoConfig.from_pretrained(model_name)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)  # 8-bit weights via bitsandbytes
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    cache_dir=cache_dir,  # defined elsewhere in the script
    quantization_config=quantization_config,
)

Full error traceback:

  File "/home/research/utils/lm_utils.py", line 156, in get_lm_prob                                               
    outputs_batch = model(input_ids=input_ids_batch, attention_mask=attention_mask_batch).logits # [ext_batch_size, seq_len, vocab_size]     
  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl                                  
    return forward_call(*input, **kwargs)                             
  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)                             
  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 688, in forward                     
    outputs = self.model(  
  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)                             
  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 578, in forward                     
    layer_outputs = decoder_layer(                                    
  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl                                  
    return forward_call(*input, **kwargs)                             
  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)                             
  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 292, in forward                     
    hidden_states, self_attn_weights, present_key_value = self.self_attn(                                                                    
  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl                                  
    return forward_call(*input, **kwargs)                             
  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)                             
  File "/home/anaconda3/envs/train_dpr_cu113/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 195, in forward                     
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[18, 217, 32, 128]' is invalid for input of size 3999744 
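For comparison, here is a sketch of the shape bookkeeping a grouped-query-attention-aware layer does, in plain Python on shapes only (no tensors). The constants are hand-copied Meta-Llama-3-8B settings and the helper names (`view_is_valid`, `key_shapes`) are hypothetical. As far as I know, transformers versions that support GQA reshape the key states with `num_key_value_heads` and then repeat them `n_rep` times (the library's `repeat_kv` helper) so they line up with the 32 query heads.

```python
from math import prod

# Hand-copied Meta-Llama-3-8B attention settings (assumed from its published config.json).
HIDDEN_SIZE = 4096
NUM_HEADS = 32                       # query heads
NUM_KV_HEADS = 8                     # key/value heads (grouped-query attention)
HEAD_DIM = HIDDEN_SIZE // NUM_HEADS  # 128


def view_is_valid(numel, shape):
    """Mimic the element-count part of the check torch.Tensor.view performs."""
    return prod(shape) == numel


def key_shapes(bsz, q_len):
    """Return k_proj's element count, both candidate views, and the KV repeat factor."""
    k_proj_numel = bsz * q_len * NUM_KV_HEADS * HEAD_DIM  # k_proj maps 4096 -> 1024 features
    old_view = (bsz, q_len, NUM_HEADS, HEAD_DIM)          # what line 195 in the traceback requests
    gqa_view = (bsz, q_len, NUM_KV_HEADS, HEAD_DIM)       # reshape by key/value heads instead
    n_rep = NUM_HEADS // NUM_KV_HEADS                     # each KV head is shared by 4 query heads
    return k_proj_numel, old_view, gqa_view, n_rep


numel, old_view, gqa_view, n_rep = key_shapes(18, 217)
print(numel, view_is_valid(numel, old_view), view_is_valid(numel, gqa_view), n_rep)
# 3999744 False True 4
```

The element count printed here matches the "input of size 3999744" in the traceback, which is why only the key/value-head view succeeds.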

My environment (the script fails on both CUDA 12.0 and 11.4):

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
accelerate                0.30.1                   pypi_0    pypi
bitsandbytes              0.42.0                   pypi_0    pypi
blas                      1.0                         mkl  
brotli-python             1.0.9            py39h6a678d5_8  
bzip2                     1.0.8                h5eee18b_6  
ca-certificates           2024.3.11            h06a4308_0  
certifi                   2024.2.2         py39h06a4308_0  
charset-normalizer        2.0.4              pyhd3eb1b0_0  
cuda-cudart               11.8.89                       0    nvidia
cuda-cupti                11.8.87                       0    nvidia
cuda-libraries            11.8.0                        0    nvidia
cuda-nvrtc                11.8.89                       0    nvidia
cuda-nvtx                 11.8.86                       0    nvidia
cuda-runtime              11.8.0                        0    nvidia
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.13.1           py39h06a4308_0  
freetype                  2.12.1               h4a9f257_0  
fsspec                    2024.5.0                 pypi_0    pypi
gmp                       6.2.1                h295c915_3  
gmpy2                     2.1.2            py39heeb90bb_0  
gnutls                    3.6.15               he1e5248_0  
huggingface-hub           0.23.1                   pypi_0    pypi
idna                      3.7              py39h06a4308_0  
intel-openmp              2023.1.0         hdb19cb5_46306  
jinja2                    3.1.3            py39h06a4308_0  
jpeg                      9e                   h5eee18b_1  
lame                      3.100                h7b6447c_0  
lcms2                     2.12                 h3be6417_0  
ld_impl_linux-64          2.38                 h1181459_1  
lerc                      3.0                  h295c915_0  
libabseil                 20240116.2      cxx17_h59595ed_0    conda-forge
libcublas                 11.11.3.6                     0    nvidia
libcufft                  10.9.0.58                     0    nvidia
libcufile                 1.9.1.3                       0    nvidia
libcurand                 10.3.5.147                    0    nvidia
libcusolver               11.4.1.48                     0    nvidia
libcusparse               11.7.5.86                     0    nvidia
libdeflate                1.17                 h5eee18b_1  
libffi                    3.4.4                h6a678d5_1  
libgcc-ng                 13.2.0               h77fa898_7    conda-forge
libgomp                   13.2.0               h77fa898_7    conda-forge
libiconv                  1.16                 h5eee18b_3  
libidn2                   2.3.4                h5eee18b_0  
libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
libnpp                    11.8.0.86                     0    nvidia
libnvjpeg                 11.9.0.86                     0    nvidia
libpng                    1.6.39               h5eee18b_0  
libprotobuf               4.25.3               h08a7969_0    conda-forge
libsentencepiece          0.2.0                hb0b37bd_1    conda-forge
libstdcxx-ng              13.2.0               hc0a3c3a_7    conda-forge
libtasn1                  4.19.0               h5eee18b_0  
libtiff                   4.5.1                h6a678d5_0  
libunistring              0.9.10               h27cfd23_0  
libwebp-base              1.3.2                h5eee18b_0  
libzlib                   1.2.13               hd590300_5    conda-forge
llvm-openmp               14.0.6               h9e868ea_0  
lz4-c                     1.9.4                h6a678d5_1  
markupsafe                2.1.3            py39h5eee18b_0  
mkl                       2023.1.0         h213fc3f_46344  
mkl-service               2.4.0            py39h5eee18b_1  
mkl_fft                   1.3.8            py39h5eee18b_0  
mkl_random                1.2.4            py39hdb19cb5_0  
mpc                       1.1.0                h10f8cd9_1  
mpfr                      4.0.2                hb69a4c5_1  
mpmath                    1.3.0            py39h06a4308_0  
ncurses                   6.4                  h6a678d5_0  
nettle                    3.7.3                hbbd107a_1  
networkx                  3.1              py39h06a4308_0  
numpy                     1.26.4           py39h5f9d8c6_0  
numpy-base                1.26.4           py39hb5e798b_0  
openh264                  2.1.1                h4ff587b_0  
openjpeg                  2.4.0                h3ad879b_0  
openssl                   3.0.13               h7f8727e_2  
packaging                 24.0                     pypi_0    pypi
pillow                    10.3.0           py39h5eee18b_0  
pip                       24.0             py39h06a4308_0  
psutil                    5.9.8                    pypi_0    pypi
pysocks                   1.7.1            py39h06a4308_0  
python                    3.9.19               h955ad1f_1  
python_abi                3.9                      2_cp39    conda-forge
pytorch                   2.2.2           py3.9_cuda11.8_cudnn8.7.0_0    pytorch
pytorch-cuda              11.8                 h7e8668a_5    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pyyaml                    6.0.1            py39h5eee18b_0  
readline                  8.2                  h5eee18b_0  
regex                     2024.5.15                pypi_0    pypi
requests                  2.31.0           py39h06a4308_1  
safetensors               0.4.3                    pypi_0    pypi
scipy                     1.13.0                   pypi_0    pypi
sentencepiece             0.2.0                hf3d152e_1    conda-forge
sentencepiece-python      0.2.0            py39ha537242_1    conda-forge
sentencepiece-spm         0.2.0                hb0b37bd_1    conda-forge
setuptools                69.5.1           py39h06a4308_0  
sqlite                    3.45.3               h5eee18b_0  
sympy                     1.12             py39h06a4308_0  
tbb                       2021.8.0             hdb19cb5_0  
tk                        8.6.14               h39e8969_0  
tokenizers                0.19.1                   pypi_0    pypi
torchaudio                2.2.2                py39_cu118    pytorch
torchtriton               2.2.0                      py39    pytorch
torchvision               0.17.2               py39_cu118    pytorch
tqdm                      4.66.4                   pypi_0    pypi
transformers              4.41.0                   pypi_0    pypi
typing_extensions         4.11.0           py39h06a4308_0  
tzdata                    2024a                h04d1e81_0  
urllib3                   2.2.1            py39h06a4308_0  
wheel                     0.43.0           py39h06a4308_0  
xz                        5.4.6                h5eee18b_1  
yaml                      0.2.5                h7b6447c_0  
zlib                      1.2.13               hd590300_5    conda-forge
zstd                      1.5.5                hc292b87_2 
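The list above reports transformers 4.41.0, yet the frame at modeling_llama.py line 195 reshapes keys with `self.num_heads`; as far as I know, GQA-aware releases reshape keys with `num_key_value_heads` instead. It may therefore be worth confirming which interpreter and package versions the failing script actually uses. A minimal standard-library sketch (the `report_versions` helper and the package list are illustrative, not part of any library):

```python
import sys
from importlib import metadata


def report_versions(packages):
    """Return {package: installed version or None} for the running interpreter."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    print("python:", sys.executable)  # shows which environment is actually running
    for name, version in report_versions(["torch", "transformers", "bitsandbytes", "accelerate"]).items():
        print(f"{name}: {version}")
```

Running this from the same command that launches the failing script avoids comparing against a different conda environment by mistake.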
