LeroyDyer commited on
Commit
384c683
1 Parent(s): 86957a5

Delete convert_mistral_weights_to_hf.py

Browse files
Files changed (1) hide show
  1. convert_mistral_weights_to_hf.py +0 -276
convert_mistral_weights_to_hf.py DELETED
@@ -1,276 +0,0 @@
1
- # Copyright 2023 Mistral AI and The HuggingFace Inc. team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import argparse
15
- import gc
16
- import json
17
- import os
18
- import shutil
19
- import warnings
20
-
21
- import torch
22
-
23
- from transformers import (
24
- LlamaTokenizer,
25
- MistralConfig,
26
- MistralForCausalLM,
27
- )
28
-
29
-
30
- try:
31
- from transformers import LlamaTokenizerFast
32
-
33
- tokenizer_class = LlamaTokenizerFast
34
- except ImportError as e:
35
- warnings.warn(e)
36
- warnings.warn(
37
- "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
38
- )
39
- tokenizer_class = LlamaTokenizer
40
-
41
- """
42
- Sample usage:
43
-
44
- ```
45
- python src/transformers/models/mistral/convert_mistral_weights_to_hf.py \
46
- --input_dir /path/to/downloaded/mistral/weights --model_size 7B --output_dir /output/path
47
- ```
48
-
49
- Thereafter, models can be loaded via:
50
-
51
- ```py
52
- from transformers import MistralForCausalLM, LlamaTokenizer
53
-
54
- model = MistralForCausalLM.from_pretrained("/output/path")
55
- tokenizer = LlamaTokenizer.from_pretrained("/output/path")
56
- ```
57
-
58
- Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
59
- come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
60
- """
61
-
62
- NUM_SHARDS = {"7B": 1}
63
-
64
-
65
- def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
66
- return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
67
-
68
-
69
- def read_json(path):
70
- with open(path, "r") as f:
71
- return json.load(f)
72
-
73
-
74
- def write_json(text, path):
75
- with open(path, "w") as f:
76
- json.dump(text, f)
77
-
78
-
79
- def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True):
80
- # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
81
- if not os.path.isfile(os.path.join(input_base_path, "params.json")):
82
- input_base_path = os.path.join(input_base_path, model_size)
83
-
84
- os.makedirs(model_path, exist_ok=True)
85
- tmp_model_path = os.path.join(model_path, "tmp")
86
- os.makedirs(tmp_model_path, exist_ok=True)
87
-
88
- params = read_json(os.path.join(input_base_path, "params.json"))
89
- num_shards = NUM_SHARDS[model_size]
90
-
91
- # For some reason this is a string in the params.json
92
- sliding_window = int(params["sliding_window"])
93
- n_layers = params["n_layers"]
94
- n_heads = params["n_heads"]
95
- n_heads_per_shard = n_heads // num_shards
96
- dim = params["dim"]
97
- dims_per_head = dim // n_heads
98
- base = params.get("rope_theta", 10000.0)
99
- inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
100
- max_position_embeddings = 4096 * 8
101
-
102
- if tokenizer_path is not None:
103
- tokenizer = tokenizer_class(tokenizer_path)
104
- tokenizer.save_pretrained(model_path)
105
- vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
106
-
107
- if "n_kv_heads" in params:
108
- num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
109
- num_local_key_value_heads = num_key_value_heads // num_shards
110
- key_value_dim = dims_per_head * num_local_key_value_heads
111
- else: # compatibility with other checkpoints
112
- num_key_value_heads = n_heads
113
- num_local_key_value_heads = n_heads_per_shard
114
- key_value_dim = dim
115
-
116
- # permute for sliced rotary
117
- def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
118
- return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
119
-
120
- print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
121
- # Load weights
122
- loaded = [
123
- torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
124
- for i in range(num_shards)
125
- ]
126
- param_count = 0
127
- index_dict = {"weight_map": {}}
128
- for layer_i in range(n_layers):
129
- filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
130
-
131
- # Sharded
132
- # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
133
- # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
134
- # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
135
-
136
- state_dict = {
137
- f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
138
- f"layers.{layer_i}.attention_norm.weight"
139
- ].clone(),
140
- f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
141
- f"layers.{layer_i}.ffn_norm.weight"
142
- ].clone(),
143
- }
144
- state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
145
- torch.cat(
146
- [
147
- loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
148
- for i in range(num_shards)
149
- ],
150
- dim=0,
151
- ).reshape(dim, dim)
152
- )
153
- state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
154
- torch.cat(
155
- [
156
- loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
157
- num_local_key_value_heads, dims_per_head, dim
158
- )
159
- for i in range(num_shards)
160
- ],
161
- dim=0,
162
- ).reshape(key_value_dim, dim),
163
- num_key_value_heads,
164
- key_value_dim,
165
- dim,
166
- )
167
- state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
168
- [
169
- loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(num_local_key_value_heads, dims_per_head, dim)
170
- for i in range(num_shards)
171
- ],
172
- dim=0,
173
- ).reshape(key_value_dim, dim)
174
-
175
- state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
176
- [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
177
- )
178
- state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
179
- [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
180
- )
181
- state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
182
- [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
183
- )
184
- state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
185
- [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
186
- )
187
-
188
- state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
189
- for k, v in state_dict.items():
190
- index_dict["weight_map"][k] = filename
191
- param_count += v.numel()
192
- torch.save(state_dict, os.path.join(tmp_model_path, filename))
193
-
194
- filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
195
- state_dict = {
196
- "model.norm.weight": loaded[0]["norm.weight"],
197
- "model.embed_tokens.weight": torch.cat([loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1),
198
- "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
199
- }
200
-
201
- for k, v in state_dict.items():
202
- index_dict["weight_map"][k] = filename
203
- param_count += v.numel()
204
- torch.save(state_dict, os.path.join(tmp_model_path, filename))
205
-
206
- # Write configs
207
- index_dict["metadata"] = {"total_size": param_count * 2}
208
- write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
209
- config = MistralConfig(
210
- hidden_size=dim,
211
- intermediate_size=params["hidden_dim"],
212
- num_attention_heads=params["n_heads"],
213
- num_hidden_layers=params["n_layers"],
214
- rms_norm_eps=params["norm_eps"],
215
- num_key_value_heads=num_key_value_heads,
216
- vocab_size=vocab_size,
217
- rope_theta=base,
218
- max_position_embeddings=max_position_embeddings,
219
- sliding_window=sliding_window,
220
- )
221
- config.save_pretrained(tmp_model_path)
222
-
223
- # Make space so we can load the model properly now.
224
- del state_dict
225
- del loaded
226
- gc.collect()
227
-
228
- print("Loading the checkpoint in a Mistral model.")
229
- model = MistralForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
230
- # Avoid saving this as part of the config.
231
- del model.config._name_or_path
232
- model.config.torch_dtype = torch.float16
233
- print("Saving in the Transformers format.")
234
- model.save_pretrained(model_path, safe_serialization=safe_serialization)
235
- shutil.rmtree(tmp_model_path)
236
-
237
-
238
- def write_tokenizer(tokenizer_path, input_tokenizer_path):
239
- # Initialize the tokenizer based on the `spm` model
240
- print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
241
- tokenizer = tokenizer_class(input_tokenizer_path)
242
- tokenizer.save_pretrained(tokenizer_path)
243
-
244
-
245
- def main():
246
- parser = argparse.ArgumentParser()
247
- parser.add_argument(
248
- "--input_dir",
249
- help="Location of Mistral weights, which contains tokenizer.model and model folders",
250
- )
251
- parser.add_argument(
252
- "--model_size",
253
- choices=["7B", "tokenizer_only"],
254
- help="'f' models correspond to the finetuned versions, and are specific to the Mistral2 official release. For more details on Mistral2, checkout the original repo: https://huggingface.co/meta-mistral",
255
- )
256
- parser.add_argument(
257
- "--output_dir",
258
- help="Location to write HF model and tokenizer",
259
- )
260
- parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
261
- args = parser.parse_args()
262
- spm_path = os.path.join(args.input_dir, "tokenizer.model")
263
- if args.model_size != "tokenizer_only":
264
- write_model(
265
- model_path=args.output_dir,
266
- input_base_path=args.input_dir,
267
- model_size=args.model_size,
268
- safe_serialization=args.safe_serialization,
269
- tokenizer_path=spm_path,
270
- )
271
- else:
272
- write_tokenizer(args.output_dir, spm_path)
273
-
274
-
275
- if __name__ == "__main__":
276
- main()