Update modeling_chatglm.py for inputs_embeds
#45
by Xipotzzz - opened

modeling_chatglm.py +22 -11

modeling_chatglm.py CHANGED
@@ -914,11 +914,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 use_cache = False
 
         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
+            logger.warning("You passed both `inputs_embeds` and `input_ids`. Will use `inputs_embeds`")
+        if input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
-            batch_size, seq_length, _ = inputs_embeds.shape
+            batch_size, seq_length = inputs_embeds.shape[:2]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
@@ -972,9 +972,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if attention_mask is None:
             attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
-
         else:
-            attention_mask = attention_mask.to(input_ids.device)
+            attention_mask = attention_mask.to(hidden_states.device)
 
         for i, layer in enumerate(self.layers):
 
@@ -1105,6 +1104,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     def prepare_inputs_for_generation(
             self,
             input_ids: torch.LongTensor,
+            inputs_embeds: Optional[torch.Tensor] = None,
             past: Optional[torch.Tensor] = None,
             past_key_values: Optional[torch.Tensor] = None,
             attention_mask: Optional[torch.Tensor] = None,
@@ -1165,12 +1165,23 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
                 use_gmasks=use_gmasks
             )
 
-            return {
-                "input_ids": input_ids,
-                "past_key_values": past,
-                "position_ids": position_ids,
-                "attention_mask": attention_mask
-            }
+            if inputs_embeds is not None:
+                assert input_ids.size(1) == inputs_embeds.size(
+                    1
+                ), f"Make sure that both input_ids ({input_ids.size(1)}) and inputs_embeds ({inputs_embeds.size(1)}) have the same length."
+                return {
+                    "inputs_embeds": inputs_embeds,
+                    "past_key_values": past,
+                    "position_ids": position_ids,
+                    "attention_mask": attention_mask,
+                }
+            else:
+                return {
+                    "input_ids": input_ids,
+                    "past_key_values": past,
+                    "position_ids": position_ids,
+                    "attention_mask": attention_mask,
+                }
 
     def forward(
             self,