hibikaze commited on
Commit
14a911c
1 Parent(s): e28ebda

Fixed beam search error when using multiple GPUs

Browse files
Files changed (1) hide show
  1. modeling_plamo.py +1 -1
modeling_plamo.py CHANGED
@@ -701,5 +701,5 @@ class PlamoForCausalLM(PlamoPreTrainedModel):
701
  def _reorder_cache(past_key_values: List[torch.FloatTensor], beam_idx: int) -> Tuple[Any, ...]:
702
  reordered_past: Tuple[Any, ...] = ()
703
  for layer_past in past_key_values:
704
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
705
  return reordered_past
 
701
  def _reorder_cache(past_key_values: List[torch.FloatTensor], beam_idx: int) -> Tuple[Any, ...]:
702
  reordered_past: Tuple[Any, ...] = ()
703
  for layer_past in past_key_values:
704
+ reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
705
  return reordered_past