Stanislas davinwang commited on
Commit
18f9ab1
1 Parent(s): f3822a7

compatible with DirectML/ROCm (#5)

Browse files

- compatible with DirectML/ROCm (5bc5aff72b4d8fbdb10f7befa1473a741ecec8b5)


Co-authored-by: Davin Wang <[email protected]>

Files changed (1) hide show
  1. modeling_chatglm.py +2 -1
modeling_chatglm.py CHANGED
@@ -16,6 +16,7 @@ from transformers.modeling_outputs import (
16
  BaseModelOutputWithPast,
17
  CausalLMOutputWithPast,
18
  )
 
19
  from transformers.modeling_utils import PreTrainedModel
20
  from transformers.utils import logging
21
  from transformers.generation.logits_process import LogitsProcessor
@@ -1138,7 +1139,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1138
  )
1139
  logits_warper = self._get_logits_warper(generation_config)
1140
 
1141
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1142
  scores = None
1143
  while True:
1144
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
16
  BaseModelOutputWithPast,
17
  CausalLMOutputWithPast,
18
  )
19
+
20
  from transformers.modeling_utils import PreTrainedModel
21
  from transformers.utils import logging
22
  from transformers.generation.logits_process import LogitsProcessor
 
1139
  )
1140
  logits_warper = self._get_logits_warper(generation_config)
1141
 
1142
+ unfinished_sequences = torch.ones(input_ids.shape[0], device=input_ids.device, dtype=input_ids.dtype)
1143
  scores = None
1144
  while True:
1145
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)