Beam search fails when running on multiple GPUs

#5
by hibikaze - opened

We performed inference under the following conditions (a loading sketch under these assumptions is shown after the list):
- transformers==4.33.1
- two GPUs (with device_map="auto")
- 8-bit quantization
- PEFT
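
For illustration, here is a minimal loading sketch under these assumptions. The model ID is taken from the traceback below; the adapter path and other details are placeholders, not the actual setup from this post.

```python
# Minimal loading sketch under the conditions above (transformers 4.33.x).
# The adapter path is a placeholder; trust_remote_code is required because
# pfnet/plamo-13b ships custom modeling code (modeling_plamo.py).
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_id = "pfnet/plamo-13b"       # model seen in the traceback
adapter_path = "path/to/peft-adapter"   # hypothetical PEFT adapter

tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",    # shards layers across both GPUs
    load_in_8bit=True,    # 8-bit quantization via bitsandbytes
    trust_remote_code=True,
)
peft_model = PeftModel.from_pretrained(base_model, adapter_path)
peft_model.eval()
```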

When using beam search as the decoding strategy, the following error occurred.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[11], line 93
     91 with torch.cuda.amp.autocast():
     92     with torch.no_grad():
---> 93         output_tokens = peft_model.generate(
     94             **batch,
     95             max_new_tokens=max_new_token_len,
     96             do_sample=False,
     97             num_beams=2,
     98             #num_beam_groups=2,
     99             #no_repeat_ngram_size=2,
    100             #early_stopping=True,
    101             #penalty_alpha=0.6,
    102             #temperature=1.0,
    103             #top_k=4,
    104             #top_p=0.95,
    105             #diversity_penalty=1.0,
    106             #repetition_penalty=1.0,
    107             #bad_words_ids=,
    108             #force_words_ids=,
    109             #constraints=,
    110             pad_token_id=tokenizer.pad_token_id,
    111             eos_token_id=tokenizer.eos_token_id,
    112         )
    114 decoded_output = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    115 print(decoded_output)

File /usr/local/lib/python3.10/dist-packages/peft/peft_model.py:975, in generate(self, **kwargs)
    969     warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
    970     kwargs["token_type_ids"] = None
    971 kwargs.update(
    972     {
    973         "attention_mask": attention_mask,
    974         "labels": labels,
--> 975         "output_attentions": output_attentions,
    976         "output_hidden_states": output_hidden_states,
    977         "return_dict": return_dict,
    978     }
    979 )
    981 if peft_config.peft_type == PeftType.PREFIX_TUNING:
    982     past_key_values = self.get_prompt(batch_size)

File /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1681, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1674     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1675         input_ids=input_ids,
   1676         expand_size=generation_config.num_beams,
   1677         is_encoder_decoder=self.config.is_encoder_decoder,
   1678         **model_kwargs,
   1679     )
   1680     # 13. run beam search
-> 1681     return self.beam_search(
   1682         input_ids,
   1683         beam_scorer,
   1684         logits_processor=logits_processor,
   1685         stopping_criteria=stopping_criteria,
   1686         pad_token_id=generation_config.pad_token_id,
   1687         eos_token_id=generation_config.eos_token_id,
   1688         output_scores=generation_config.output_scores,
   1689         return_dict_in_generate=generation_config.return_dict_in_generate,
   1690         synced_gpus=synced_gpus,
   1691         **model_kwargs,
   1692     )
   1694 elif generation_mode == GenerationMode.BEAM_SAMPLE:
   1695     # 11. prepare logits warper
   1696     logits_warper = self._get_logits_warper(generation_config)

File /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:3091, in GenerationMixin.beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
   3087 model_kwargs = self._update_model_kwargs_for_generation(
   3088     outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
   3089 )
   3090 if model_kwargs["past_key_values"] is not None:
-> 3091     model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
   3093 if return_dict_in_generate and output_scores:
   3094     beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

File ~/.cache/huggingface/modules/transformers_modules/pfnet/plamo-13b/e28ebda68a728f36f9279afe14cb68e94ac95eff/modeling_plamo.py:704, in PlamoForCausalLM._reorder_cache(past_key_values, beam_idx)
    702 reordered_past: Tuple[Any, ...] = ()
    703 for layer_past in past_key_values:
--> 704     reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
    705 return reordered_past

File ~/.cache/huggingface/modules/transformers_modules/pfnet/plamo-13b/e28ebda68a728f36f9279afe14cb68e94ac95eff/modeling_plamo.py:704, in <genexpr>(.0)
    702 reordered_past: Tuple[Any, ...] = ()
    703 for layer_past in past_key_values:
--> 704     reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
    705 return reordered_past

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

I was able to prevent the error by modifying "modeling_plamo.py" as shown in the commit below. (The repository was created for verification purposes and will be deleted once it is no longer needed.)
https://huggingface.co/hibikaze/change-modeling-plamo-13b/commit/1435eef93e8ed93a3faa1592080bdf0c08765933

(reference)
https://github.com/huggingface/transformers/blob/2d8ee9817c0ad750b37e7fefef692a5c473b5770/src/transformers/models/opt/modeling_opt.py#L1007
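
For context, the failure happens because device_map="auto" places different layers, and therefore their cached key/value tensors, on different GPUs, while beam_idx lives on a single device, so index_select ends up mixing cuda:0 and cuda:1. The sketch below shows the kind of change involved, following the OPT implementation linked above; the actual diff is in the commit linked earlier.

```python
# Sketch of _reorder_cache with the device fix, mirroring the OPT
# implementation referenced above; not the verbatim committed change.
from typing import Any, Tuple

import torch


def _reorder_cache(
    past_key_values: Tuple[Tuple[torch.Tensor, ...], ...], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor, ...], ...]:
    reordered_past: Tuple[Any, ...] = ()
    for layer_past in past_key_values:
        # Move beam_idx onto each cached tensor's device so index_select
        # works even when layers are sharded across multiple GPUs.
        reordered_past += (
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            ),
        )
    return reordered_past
```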

Thank you for your suggestion! Your proposed change seems beneficial. If you're interested, would you be willing to submit a PR? If not, don't worry, we can implement the change on our end. Either way, we appreciate your input!

Thank you for the confirmation.
I've submitted a pull request.

dhigurashi changed discussion status to closed
