RyanYr committed
Commit daa9896
Parent: 0c01d6d

Training in progress, step 300, checkpoint

last-checkpoint/config.json CHANGED
@@ -30,7 +30,7 @@
   "rope_theta": 10000.0,
   "sliding_window": 4096,
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.44.2",
+  "transformers_version": "4.43.4",
   "use_cache": false,
   "vocab_size": 256001
 }
last-checkpoint/generation_config.json CHANGED
@@ -7,5 +7,5 @@
     107
   ],
   "pad_token_id": 0,
-  "transformers_version": "4.44.2"
+  "transformers_version": "4.43.4"
 }
last-checkpoint/global_step300/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a4dfa001510f1245942dafc19e25f44adac764fdeb5d88d10eddfcdf09286fea
+oid sha256:490395e6ddb61a55d1f5621fad1deedde177291f376a1cab36fc05a41e4c8ca6
 size 7843036668
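These .pt entries are Git LFS pointer files: only the sha256 oid changes between commits, the byte size stays identical, and the actual tensors live in LFS storage. A minimal sketch (standard library only; the path and expected oid are copied from the pointer above) for verifying a fetched shard:

    import hashlib

    def lfs_sha256(path, chunk_size=1 << 20):
        # stream in 1 MiB chunks so multi-GB optimizer shards never sit in memory
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()

    expected = "490395e6ddb61a55d1f5621fad1deedde177291f376a1cab36fc05a41e4c8ca6"
    assert lfs_sha256("bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt") == expected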
last-checkpoint/global_step300/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0100073d56a06b572003930c4647b70a538edc7ac67241b40f5f833e3040b10f
+oid sha256:8cb520c96f9c73c5d0868140ecc4ea16fea2f0da04ee42d790c9a77c67e41f62
 size 7843043580
last-checkpoint/global_step300/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0fe3b6430e4ff02932e1ebb8820df6c78493f72a94d73674b3a337a24567dfb1
+oid sha256:27e83bf3193945caa8218113e40161c4401aee5092fda15fbdb265676a30c2fa
 size 7843043004
last-checkpoint/global_step300/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:b0bfb0cc7923c504ff6a6eb192b3921cca2b134cd36cb6b634c59853033cab99
+oid sha256:44d4dd32831c7a9646cd49d1a6dbc7df7f8738372bd5a23dd0f3c1b95f5118cb
 size 7843043388
last-checkpoint/global_step300/mp_rank_00_model_states.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:df4c97c6d836cd16ac31fff03c6e56d8d7867734ff9d9ee4f55f68df14601d63
+oid sha256:f855eb13a1f303c00e73dda783c536d3bba9a9118989fde748f77807ba0c98bd
 size 5228775200
last-checkpoint/latest CHANGED
@@ -1 +1 @@
-global_step1606
+global_step300
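DeepSpeed uses `latest` as a tag marker: the parent commit pointed at global_step1606, while this checkpoint points at global_step300, the directory holding the optimizer shards above. A minimal sketch of resolving it, mirroring the "attempt to load tag in 'latest' file" behavior documented in zero_to_fp32.py below (the helper name is illustrative):

    import os

    def read_latest_tag(checkpoint_dir):
        # returns e.g. "global_step300" for this checkpoint
        with open(os.path.join(checkpoint_dir, "latest")) as f:
            return f.read().strip()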
last-checkpoint/model-00001-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8b6ef8a2b6bc11bee46b965fda81f699192d9ea76dcc92e5446ae71eb00afd97
+oid sha256:0a95052f77e0ee3c96f0f7cae869217e2cc93260d3e50d9852b75c2f0521adcb
 size 4988030368
last-checkpoint/model-00002-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8decfa8858e28a8f92e2d8ec671fa359841b2180e19beb1dd64dc77862ba7ef1
+oid sha256:7c5a4fd382cb8a85c300b8a5648653039f0191901b30d466d7a75044be659a1d
 size 1420344488
last-checkpoint/rng_state_0.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:bbe0d720c4c75a6a04213fa3b64bacbe794718a53e2b56ebb67a1a795014dfad
+oid sha256:92cc13315f24c28015d695b6cde08bb1cd6fea4cbc435998485ed6fbe4c91285
 size 15024
last-checkpoint/rng_state_1.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:72452d3138d0ca2ff89429e3294a834ae7a68e8596fc757735ca56ae52509d57
+oid sha256:f4c154b6a63e0b1f98f7d2847944398f99f1657d35e8eddf7fdf0ae2c24b0552
 size 15024
last-checkpoint/rng_state_2.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f36e306fb8ebcf53a167bfd6c9af74db410a269ada1e619e3e816f5269543b9d
+oid sha256:f784c6a9507b51189f2caffbd178ea9882103b75852e31c15f47fdae6a43af1d
 size 15024
last-checkpoint/rng_state_3.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:bb47ce0c6f815a6f8302b0e3819b4c2315ca71dae3138d97fdceb765cdd0a039
+oid sha256:34b023e05bc2d12b91dc436d4922b990d50ec8dc56d40dc3e36b3bb34fc81341
 size 15024
last-checkpoint/scheduler.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5004e0a8a423ac7767518d3c523b7ff1e2ad0ddfefc734167783821a2bfb69a6
+oid sha256:3c9281bdbed11a5fa989179d3990f8ea1577b41ba21b300ef8adcb469edf99b5
 size 1064
last-checkpoint/tokenizer.json CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7ad81132a729860bdb9e4d2e22e3ae09f317f539aac46d8acc4e17c9412f0870
+oid sha256:987ad1b8e70d3ba898f587a434ba487d544c2800b1b9dcf020ffcbe7a5ac1d12
 size 17525539
last-checkpoint/trainer_state.json CHANGED
The diff for this file is too large to render. See raw diff
 
last-checkpoint/training_args.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:46f9a52e7c282501a8e3028dacc99bc68c098efa306a7cbd4983a9b5a70c118a
+oid sha256:45001beea15a94e59d46abe242dc94b7cf1ec836fd936b0dfb68a48a54f36abe
 size 7096
last-checkpoint/zero_to_fp32.py CHANGED
@@ -191,7 +191,7 @@ def parse_optim_states(files, ds_checkpoint_dir):
     return zero_stage, world_size, fp32_flat_groups
 
 
-def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
     """
     Returns fp32 state_dict reconstructed from ds checkpoint
 
@@ -211,11 +211,9 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
     print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
 
     if zero_stage <= 2:
-        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
-                                                          exclude_frozen_parameters)
+        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
     elif zero_stage == 3:
-        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
-                                                          exclude_frozen_parameters)
+        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
 
 
 def _zero2_merge_frozen_params(state_dict, zero_model_states):
@@ -250,11 +248,6 @@ def _zero2_merge_frozen_params(state_dict, zero_model_states):
     print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
 
 
-def _has_callable(obj, fn):
-    attr = getattr(obj, fn, None)
-    return callable(attr)
-
-
 def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
     param_shapes = zero_model_states[0].param_shapes
 
@@ -294,7 +287,7 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
         avail_numel = full_single_fp32_vector.numel()
         for name, shape in shapes.items():
 
-            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+            unpartitioned_numel = shape.numel()
             total_numel += unpartitioned_numel
             total_params += 1
 
@@ -328,8 +321,7 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
     print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
 
 
-def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
-                                               exclude_frozen_parameters):
+def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
     state_dict = OrderedDict()
 
     # buffers
@@ -338,8 +330,7 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
     if debug:
         print(f"added {len(buffers)} buffers")
 
-    if not exclude_frozen_parameters:
-        _zero2_merge_frozen_params(state_dict, zero_model_states)
+    _zero2_merge_frozen_params(state_dict, zero_model_states)
 
     _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
 
@@ -448,8 +439,7 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
     print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
 
 
-def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
-                                               exclude_frozen_parameters):
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
     state_dict = OrderedDict()
 
     # buffers
@@ -458,8 +448,7 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
     if debug:
         print(f"added {len(buffers)} buffers")
 
-    if not exclude_frozen_parameters:
-        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+    _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
 
     _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
 
@@ -471,7 +460,7 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
     return state_dict
 
 
-def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
     ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
@@ -480,7 +469,6 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
     Args:
         - ``checkpoint_dir``: path to the desired checkpoint folder
         - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
-        - ``exclude_frozen_parameters``: exclude frozen parameters
 
     Returns:
         - pytorch ``state_dict``
@@ -518,10 +506,10 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
     if not os.path.isdir(ds_checkpoint_dir):
         raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
 
-    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
+    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
 
 
-def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
@@ -530,10 +518,9 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
     - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
     - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
     - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
-    - ``exclude_frozen_parameters``: exclude frozen parameters
     """
 
-    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
     print(f"Saving fp32 state dict to {output_file}")
     torch.save(state_dict, output_file)
 
@@ -592,13 +579,9 @@ if __name__ == "__main__":
                         type=str,
                         default=None,
                         help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
-    parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
     parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
     args = parser.parse_args()
 
     debug = args.debug
 
-    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
-                                               args.output_file,
-                                               tag=args.tag,
-                                               exclude_frozen_parameters=args.exclude_frozen_parameters)
+    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)
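After this commit the bundled converter no longer accepts exclude_frozen_parameters (frozen params are always merged) and drops the _has_callable/math.prod fallback, so every recorded parameter shape is assumed to expose .numel(). A minimal usage sketch against this checkpoint (the output filename is illustrative):

    # run from the repo root; consolidates the ZeRO shards into one fp32 state_dict
    from zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

    convert_zero_checkpoint_to_fp32_state_dict(
        "last-checkpoint",           # folder containing `latest` and global_step300/
        "pytorch_model_fp32.bin",    # written via torch.save
        tag="global_step300",        # optional; defaults to the tag in `latest`
    )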