Support dynamic ntk rope
Browse files- modeling_internlm.py +163 -73
modeling_internlm.py
CHANGED
@@ -19,26 +19,36 @@
|
|
19 |
# limitations under the License.
|
20 |
""" PyTorch InternLM model."""
|
21 |
import math
|
|
|
|
|
22 |
from typing import List, Optional, Tuple, Union
|
23 |
-
import threading, queue
|
24 |
|
25 |
import torch
|
26 |
import torch.utils.checkpoint
|
27 |
from torch import nn
|
28 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
29 |
-
|
30 |
from transformers.activations import ACT2FN
|
31 |
-
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
32 |
-
from transformers.modeling_utils import PreTrainedModel
|
33 |
from transformers.generation.streamers import BaseStreamer
|
34 |
-
from transformers.
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
|
|
37 |
|
38 |
logger = logging.get_logger(__name__)
|
39 |
|
40 |
_CONFIG_FOR_DOC = "InternLMConfig"
|
41 |
|
|
|
42 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
43 |
def _make_causal_mask(
|
44 |
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
@@ -73,6 +83,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|
73 |
|
74 |
|
75 |
class InternLMRMSNorm(nn.Module):
|
|
|
|
|
76 |
def __init__(self, hidden_size, eps=1e-6):
|
77 |
"""
|
78 |
InternLMRMSNorm is equivalent to T5LayerNorm
|
@@ -93,6 +105,14 @@ class InternLMRMSNorm(nn.Module):
|
|
93 |
|
94 |
|
95 |
class InternLMRotaryEmbedding(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
97 |
super().__init__()
|
98 |
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
@@ -124,6 +144,66 @@ class InternLMRotaryEmbedding(torch.nn.Module):
|
|
124 |
)
|
125 |
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
def rotate_half(x):
|
128 |
"""Rotates half the hidden dims of the input."""
|
129 |
x1 = x[..., : x.shape[-1] // 2]
|
@@ -135,10 +215,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|
135 |
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
136 |
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
137 |
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
138 |
-
cos = cos
|
139 |
-
sin = sin
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
return q_embed, k_embed
|
143 |
|
144 |
|
@@ -179,7 +267,25 @@ class InternLMAttention(nn.Module):
|
|
179 |
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
180 |
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
181 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
|
182 |
-
self.rotary_emb =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
185 |
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
@@ -199,20 +305,18 @@ class InternLMAttention(nn.Module):
|
|
199 |
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
200 |
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
201 |
|
202 |
-
kv_seq_len = key_states.shape[-2]
|
203 |
-
if past_key_value is not None:
|
204 |
-
kv_seq_len += past_key_value[0].shape[-2]
|
205 |
-
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
206 |
-
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
207 |
-
# [bsz, nh, t, hd]
|
208 |
-
|
209 |
if past_key_value is not None:
|
210 |
# reuse k, v, self_attention
|
211 |
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
212 |
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
213 |
|
|
|
214 |
past_key_value = (key_states, value_states) if use_cache else None
|
215 |
|
|
|
|
|
|
|
|
|
216 |
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
217 |
|
218 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
@@ -322,11 +426,9 @@ INTERNLM_START_DOCSTRING = r"""
|
|
322 |
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
323 |
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
324 |
etc.)
|
325 |
-
|
326 |
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
327 |
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
328 |
and behavior.
|
329 |
-
|
330 |
Parameters:
|
331 |
config ([`InternLMConfig`]):
|
332 |
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
@@ -367,44 +469,34 @@ INTERNLM_INPUTS_DOCSTRING = r"""
|
|
367 |
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
368 |
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
369 |
it.
|
370 |
-
|
371 |
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
372 |
[`PreTrainedTokenizer.__call__`] for details.
|
373 |
-
|
374 |
[What are input IDs?](../glossary#input-ids)
|
375 |
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
376 |
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
377 |
-
|
378 |
- 1 for tokens that are **not masked**,
|
379 |
- 0 for tokens that are **masked**.
|
380 |
-
|
381 |
[What are attention masks?](../glossary#attention-mask)
|
382 |
-
|
383 |
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
384 |
[`PreTrainedTokenizer.__call__`] for details.
|
385 |
-
|
386 |
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
387 |
`past_key_values`).
|
388 |
-
|
389 |
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
390 |
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
391 |
information on the default strategy.
|
392 |
-
|
393 |
- 1 indicates the head is **not masked**,
|
394 |
- 0 indicates the head is **masked**.
|
395 |
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
396 |
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
397 |
config.n_positions - 1]`.
|
398 |
-
|
399 |
[What are position IDs?](../glossary#position-ids)
|
400 |
-
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
|
|
|
401 |
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
402 |
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
403 |
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
404 |
-
|
405 |
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
406 |
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
407 |
-
|
408 |
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
409 |
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
410 |
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
@@ -433,10 +525,10 @@ INTERNLM_INPUTS_DOCSTRING = r"""
|
|
433 |
class InternLMModel(InternLMPreTrainedModel):
|
434 |
"""
|
435 |
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLMDecoderLayer`]
|
436 |
-
|
437 |
Args:
|
438 |
config: InternLMConfig
|
439 |
"""
|
|
|
440 |
_auto_class = "AutoModel"
|
441 |
|
442 |
def __init__(self, config: InternLMConfig):
|
@@ -662,20 +754,14 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
662 |
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
663 |
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
664 |
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
665 |
-
|
666 |
Returns:
|
667 |
-
|
668 |
Example:
|
669 |
-
|
670 |
```python
|
671 |
>>> from transformers import AutoTokenizer, InternLMForCausalLM
|
672 |
-
|
673 |
>>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
674 |
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
675 |
-
|
676 |
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
677 |
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
678 |
-
|
679 |
>>> # Generate
|
680 |
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
681 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
@@ -765,50 +851,56 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
765 |
for layer_past in past_key_values:
|
766 |
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
767 |
return reordered_past
|
768 |
-
|
769 |
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
|
770 |
prompt = ""
|
771 |
for record in history:
|
772 |
prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
|
773 |
prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
|
774 |
return tokenizer([prompt], return_tensors="pt")
|
775 |
-
|
776 |
@torch.no_grad()
|
777 |
-
def chat(
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
|
782 |
-
|
783 |
-
|
784 |
-
|
785 |
-
|
786 |
-
|
|
|
|
|
787 |
inputs = self.build_inputs(tokenizer, query, history)
|
788 |
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
789 |
-
outputs = self.generate(
|
790 |
-
|
791 |
-
|
792 |
-
|
793 |
-
|
794 |
-
|
795 |
-
|
796 |
-
|
|
|
|
|
797 |
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
798 |
response = response.split("<eoa>")[0]
|
799 |
history = history + [(query, response)]
|
800 |
return response, history
|
801 |
-
|
802 |
@torch.no_grad()
|
803 |
-
def stream_chat(
|
804 |
-
|
805 |
-
|
806 |
-
|
807 |
-
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
-
|
|
|
|
|
812 |
"""
|
813 |
Return a generator in format: (response, history)
|
814 |
Eg.
|
@@ -854,12 +946,12 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
854 |
tokenizer=tokenizer,
|
855 |
query=query,
|
856 |
streamer=ChatStreamer(tokenizer=tokenizer),
|
857 |
-
history=history,
|
858 |
max_new_tokens=max_new_tokens,
|
859 |
do_sample=do_sample,
|
860 |
temperature=temperature,
|
861 |
top_p=top_p,
|
862 |
-
**kwargs
|
863 |
)
|
864 |
|
865 |
def consumer():
|
@@ -877,10 +969,8 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
|
|
877 |
@add_start_docstrings(
|
878 |
"""
|
879 |
The InternLM Model transformer with a sequence classification head on top (linear layer).
|
880 |
-
|
881 |
[`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
882 |
(e.g. GPT-2) do.
|
883 |
-
|
884 |
Since it does classification on the last token, it requires to know the position of the last token. If a
|
885 |
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
886 |
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
|
|
19 |
# limitations under the License.
|
20 |
""" PyTorch InternLM model."""
|
21 |
import math
|
22 |
+
import queue
|
23 |
+
import threading
|
24 |
from typing import List, Optional, Tuple, Union
|
|
|
25 |
|
26 |
import torch
|
27 |
import torch.utils.checkpoint
|
28 |
from torch import nn
|
29 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
30 |
from transformers.activations import ACT2FN
|
|
|
|
|
31 |
from transformers.generation.streamers import BaseStreamer
|
32 |
+
from transformers.modeling_outputs import (
|
33 |
+
BaseModelOutputWithPast,
|
34 |
+
CausalLMOutputWithPast,
|
35 |
+
SequenceClassifierOutputWithPast,
|
36 |
+
)
|
37 |
+
from transformers.modeling_utils import PreTrainedModel
|
38 |
+
from transformers.utils import (
|
39 |
+
add_start_docstrings,
|
40 |
+
add_start_docstrings_to_model_forward,
|
41 |
+
logging,
|
42 |
+
replace_return_docstrings,
|
43 |
+
)
|
44 |
|
45 |
+
from .configuration_internlm import InternLMConfig
|
46 |
|
47 |
logger = logging.get_logger(__name__)
|
48 |
|
49 |
_CONFIG_FOR_DOC = "InternLMConfig"
|
50 |
|
51 |
+
|
52 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
53 |
def _make_causal_mask(
|
54 |
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
|
|
83 |
|
84 |
|
85 |
class InternLMRMSNorm(nn.Module):
|
86 |
+
"""RMSNorm implemention."""
|
87 |
+
|
88 |
def __init__(self, hidden_size, eps=1e-6):
|
89 |
"""
|
90 |
InternLMRMSNorm is equivalent to T5LayerNorm
|
|
|
105 |
|
106 |
|
107 |
class InternLMRotaryEmbedding(torch.nn.Module):
|
108 |
+
"""Implement InternLM's rotary embedding.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
dim (int): Characteristic dimension of each self-attentional head.
|
112 |
+
max_position_embeddings (int, optional): Model's training length. Defaults to 2048.
|
113 |
+
base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
|
114 |
+
device (Any, optional): Running device. Defaults to None.
|
115 |
+
"""
|
116 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
117 |
super().__init__()
|
118 |
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
|
|
144 |
)
|
145 |
|
146 |
|
147 |
+
class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
|
148 |
+
"""Implement InternLM's DyanmicNTK extrapolation method, thereby broadening the model support context to 16K.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
dim (int): Characteristic dimension of each self-attentional head.
|
152 |
+
max_position_embeddings (int, optional): Model's training length. Defaults to 2048.
|
153 |
+
base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
|
154 |
+
device (Any, optional): Running device. Defaults to None.
|
155 |
+
scaling_factor (float, optional): NTK method extrapolation coefficient. Defaults to 1.0.
|
156 |
+
"""
|
157 |
+
|
158 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
159 |
+
super().__init__()
|
160 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
161 |
+
self.register_buffer("inv_freq", inv_freq)
|
162 |
+
self.dim = dim
|
163 |
+
self.base = base
|
164 |
+
self.scaling_factor = scaling_factor
|
165 |
+
|
166 |
+
# Build here to make `torch.jit.trace` work.
|
167 |
+
self.max_position_embeddings = max_position_embeddings
|
168 |
+
self.max_seq_len_cached = max_position_embeddings
|
169 |
+
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
170 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
171 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
172 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
173 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
174 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
175 |
+
|
176 |
+
def _update_cached(self, x, seq_len=None):
|
177 |
+
self.max_seq_len_cached = max(seq_len, self.max_position_embeddings)
|
178 |
+
if seq_len > self.max_position_embeddings:
|
179 |
+
base = self.base * (
|
180 |
+
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
181 |
+
) ** (self.dim / (self.dim - 2))
|
182 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
|
183 |
+
else:
|
184 |
+
inv_freq = self.inv_freq
|
185 |
+
t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype)
|
186 |
+
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
187 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
188 |
+
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
189 |
+
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
190 |
+
|
191 |
+
def forward(self, x, seq_len=None):
|
192 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
193 |
+
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
194 |
+
if seq_len <= self.max_position_embeddings:
|
195 |
+
# Reset the tables if the sequence length has changed,
|
196 |
+
if self.max_seq_len_cached > self.max_position_embeddings:
|
197 |
+
self._update_cached(x, seq_len)
|
198 |
+
else:
|
199 |
+
self._update_cached(x, seq_len)
|
200 |
+
|
201 |
+
return (
|
202 |
+
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
203 |
+
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
204 |
+
)
|
205 |
+
|
206 |
+
|
207 |
def rotate_half(x):
|
208 |
"""Rotates half the hidden dims of the input."""
|
209 |
x1 = x[..., : x.shape[-1] // 2]
|
|
|
215 |
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
216 |
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
217 |
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
218 |
+
cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
|
219 |
+
sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
|
220 |
+
if q.size(2) == 1:
|
221 |
+
q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
|
222 |
+
else:
|
223 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
224 |
+
|
225 |
+
if k.size(2) == 1:
|
226 |
+
k_embed = (k * cos[:, :, -1, :]) + (rotate_half(k) * sin[:, :, -1, :])
|
227 |
+
else:
|
228 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
229 |
+
|
230 |
return q_embed, k_embed
|
231 |
|
232 |
|
|
|
267 |
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
268 |
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
269 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
|
270 |
+
self.rotary_emb = self._init_rope()
|
271 |
+
|
272 |
+
def _init_rope(self):
|
273 |
+
if self.config.rotary["type"] == "origin":
|
274 |
+
self.rotary_emb = InternLMRotaryEmbedding(
|
275 |
+
self.head_dim,
|
276 |
+
max_position_embeddings=self.max_position_embeddings,
|
277 |
+
base=self.config.rotary["base"],
|
278 |
+
)
|
279 |
+
elif self.config.rotary["type"] == "dynamic":
|
280 |
+
self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
|
281 |
+
self.head_dim,
|
282 |
+
max_position_embeddings=self.max_position_embeddings,
|
283 |
+
base=self.config.rotary["base"],
|
284 |
+
scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
|
285 |
+
)
|
286 |
+
else:
|
287 |
+
raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
|
288 |
+
return self.rotary_emb
|
289 |
|
290 |
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
291 |
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
|
|
305 |
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
306 |
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
308 |
if past_key_value is not None:
|
309 |
# reuse k, v, self_attention
|
310 |
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
311 |
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
312 |
|
313 |
+
# print(use_cache)
|
314 |
past_key_value = (key_states, value_states) if use_cache else None
|
315 |
|
316 |
+
kv_seq_len = key_states.shape[-2]
|
317 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
318 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
319 |
+
|
320 |
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
321 |
|
322 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
|
|
426 |
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
427 |
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
428 |
etc.)
|
|
|
429 |
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
430 |
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
431 |
and behavior.
|
|
|
432 |
Parameters:
|
433 |
config ([`InternLMConfig`]):
|
434 |
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
|
|
469 |
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
470 |
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
471 |
it.
|
|
|
472 |
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
473 |
[`PreTrainedTokenizer.__call__`] for details.
|
|
|
474 |
[What are input IDs?](../glossary#input-ids)
|
475 |
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
476 |
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
|
477 |
- 1 for tokens that are **not masked**,
|
478 |
- 0 for tokens that are **masked**.
|
|
|
479 |
[What are attention masks?](../glossary#attention-mask)
|
|
|
480 |
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
481 |
[`PreTrainedTokenizer.__call__`] for details.
|
|
|
482 |
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
483 |
`past_key_values`).
|
|
|
484 |
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
485 |
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
486 |
information on the default strategy.
|
|
|
487 |
- 1 indicates the head is **not masked**,
|
488 |
- 0 indicates the head is **masked**.
|
489 |
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
490 |
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
491 |
config.n_positions - 1]`.
|
|
|
492 |
[What are position IDs?](../glossary#position-ids)
|
493 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
|
494 |
+
when `config.use_cache=True`):
|
495 |
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
496 |
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
497 |
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
|
|
498 |
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
499 |
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
|
|
500 |
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
501 |
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
502 |
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
|
525 |
class InternLMModel(InternLMPreTrainedModel):
|
526 |
"""
|
527 |
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLMDecoderLayer`]
|
|
|
528 |
Args:
|
529 |
config: InternLMConfig
|
530 |
"""
|
531 |
+
|
532 |
_auto_class = "AutoModel"
|
533 |
|
534 |
def __init__(self, config: InternLMConfig):
|
|
|
754 |
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
755 |
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
756 |
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
|
757 |
Returns:
|
|
|
758 |
Example:
|
|
|
759 |
```python
|
760 |
>>> from transformers import AutoTokenizer, InternLMForCausalLM
|
|
|
761 |
>>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
762 |
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
|
|
763 |
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
764 |
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
|
765 |
>>> # Generate
|
766 |
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
767 |
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
|
851 |
for layer_past in past_key_values:
|
852 |
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
853 |
return reordered_past
|
854 |
+
|
855 |
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
|
856 |
prompt = ""
|
857 |
for record in history:
|
858 |
prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
|
859 |
prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
|
860 |
return tokenizer([prompt], return_tensors="pt")
|
861 |
+
|
862 |
@torch.no_grad()
|
863 |
+
def chat(
|
864 |
+
self,
|
865 |
+
tokenizer,
|
866 |
+
query: str,
|
867 |
+
history: List[Tuple[str, str]] = [],
|
868 |
+
streamer: Optional[BaseStreamer] = None,
|
869 |
+
max_new_tokens: int = 1024,
|
870 |
+
do_sample: bool = True,
|
871 |
+
temperature: float = 0.8,
|
872 |
+
top_p: float = 0.8,
|
873 |
+
**kwargs,
|
874 |
+
):
|
875 |
inputs = self.build_inputs(tokenizer, query, history)
|
876 |
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
|
877 |
+
outputs = self.generate(
|
878 |
+
**inputs,
|
879 |
+
streamer=streamer,
|
880 |
+
max_new_tokens=max_new_tokens,
|
881 |
+
do_sample=do_sample,
|
882 |
+
temperature=temperature,
|
883 |
+
top_p=top_p,
|
884 |
+
**kwargs,
|
885 |
+
)
|
886 |
+
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
|
887 |
response = tokenizer.decode(outputs, skip_special_tokens=True)
|
888 |
response = response.split("<eoa>")[0]
|
889 |
history = history + [(query, response)]
|
890 |
return response, history
|
891 |
+
|
892 |
@torch.no_grad()
|
893 |
+
def stream_chat(
|
894 |
+
self,
|
895 |
+
tokenizer,
|
896 |
+
query: str,
|
897 |
+
history: List[Tuple[str, str]] = [],
|
898 |
+
max_new_tokens: int = 1024,
|
899 |
+
do_sample: bool = True,
|
900 |
+
temperature: float = 0.8,
|
901 |
+
top_p: float = 0.8,
|
902 |
+
**kwargs,
|
903 |
+
):
|
904 |
"""
|
905 |
Return a generator in format: (response, history)
|
906 |
Eg.
|
|
|
946 |
tokenizer=tokenizer,
|
947 |
query=query,
|
948 |
streamer=ChatStreamer(tokenizer=tokenizer),
|
949 |
+
history=history,
|
950 |
max_new_tokens=max_new_tokens,
|
951 |
do_sample=do_sample,
|
952 |
temperature=temperature,
|
953 |
top_p=top_p,
|
954 |
+
**kwargs,
|
955 |
)
|
956 |
|
957 |
def consumer():
|
|
|
969 |
@add_start_docstrings(
|
970 |
"""
|
971 |
The InternLM Model transformer with a sequence classification head on top (linear layer).
|
|
|
972 |
[`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
973 |
(e.g. GPT-2) do.
|
|
|
974 |
Since it does classification on the last token, it requires to know the position of the last token. If a
|
975 |
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
976 |
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|