Fixed beam search error when using multiple GPUs
Browse files- modeling_plamo.py +1 -1
modeling_plamo.py
CHANGED
@@ -701,5 +701,5 @@ class PlamoForCausalLM(PlamoPreTrainedModel):
|
|
701 |
def _reorder_cache(past_key_values: List[torch.FloatTensor], beam_idx: int) -> Tuple[Any, ...]:
|
702 |
reordered_past: Tuple[Any, ...] = ()
|
703 |
for layer_past in past_key_values:
|
704 |
-
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
705 |
return reordered_past
|
|
|
701 |
def _reorder_cache(past_key_values: List[torch.FloatTensor], beam_idx: int) -> Tuple[Any, ...]:
|
702 |
reordered_past: Tuple[Any, ...] = ()
|
703 |
for layer_past in past_key_values:
|
704 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
|
705 |
return reordered_past
|