Commit
•
d155d07
1
Parent(s):
e28ebda
Fixed beam search error when using multiple GPUs (#6)
Browse files- Fixed beam search error when using multiple GPUs (14a911cdcbe50c9fbaae49af6f26b3d981b30cde)
Co-authored-by: Hiroki Yamaguchi <[email protected]>
- modeling_plamo.py +1 -1
modeling_plamo.py
CHANGED
@@ -701,5 +701,5 @@ class PlamoForCausalLM(PlamoPreTrainedModel):
|
|
701 |
def _reorder_cache(past_key_values: List[torch.FloatTensor], beam_idx: int) -> Tuple[Any, ...]:
|
702 |
reordered_past: Tuple[Any, ...] = ()
|
703 |
for layer_past in past_key_values:
|
704 |
-
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
705 |
return reordered_past
|
|
|
701 |
def _reorder_cache(past_key_values: List[torch.FloatTensor], beam_idx: int) -> Tuple[Any, ...]:
|
702 |
reordered_past: Tuple[Any, ...] = ()
|
703 |
for layer_past in past_key_values:
|
704 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),)
|
705 |
return reordered_past
|