theimageconvert2 / transformers_4_35_0 /modeling_tf_pytorch_utils.py
mart9992's picture
m
4c65bff
raw
history blame
24.8 kB
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - TF 2.0 general utilities."""
import os
import re
import numpy
from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze, tensor_size
from .utils import transpose as transpose_func
logger = logging.get_logger(__name__)
class TransposeType(ExplicitEnum):
    """
    The kinds of axis permutation that may be required when a weight is moved between TF 2.0 and PyTorch.

    The member selected for a given weight is computed by `convert_tf_weight_name_to_pt_weight_name` and consumed by
    `apply_transpose`.
    """

    # Leave the weight untouched.
    NO = "no"
    # Plain transpose with the default axis order.
    SIMPLE = "simple"
    # Rank-3 convolution kernel: axes are permuted with (2, 1, 0).
    CONV1D = "conv1d"
    # Rank-4 convolution kernel: axes are permuted with (2, 3, 1, 0) or its inverse, depending on direction.
    CONV2D = "conv2d"
def convert_tf_weight_name_to_pt_weight_name(
    tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None
):
    """
    Convert a TF 2.0 model variable name in a pytorch model weight name.

    Conventions for TF2.0 scopes -> PyTorch attribute names conversions:

        - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
        - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)

    Args:
        tf_name: Name of the TF 2.0 variable (levels separated by "/", optionally suffixed with ":0").
        start_prefix_to_remove: String removed once from the resulting dotted name. Empty string disables removal.
        tf_weight_shape: Shape of the TF weight, or `None` when unknown. Only its length (rank) is used, to tell
            convolution kernels apart from dense kernels.
        name_scope: Optional scope that `tf_name` is expected to start with; it is stripped before conversion.

    Returns:
        tuple with:

        - pytorch model weight name
        - transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be
          transposed with regards to each other

    Raises:
        ValueError: If `name_scope` is given and `tf_name` does not start with it.
    """
    if name_scope is not None:
        if not tf_name.startswith(name_scope):
            raise ValueError(
                f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error "
                "in Transformers, so (unless you were doing something really evil) please open an issue to report it!"
            )
        tf_name = tf_name[len(name_scope) :]
        tf_name = tf_name.lstrip("/")
    tf_name = tf_name.replace(":0", "")  # device ids
    tf_name = re.sub(
        r"/[^/]*___([^/]*)/", r"/\1/", tf_name
    )  # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
    tf_name = tf_name.replace(
        "_._", "/"
    )  # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
    tf_name = re.sub(r"//+", "/", tf_name)  # Remove empty levels at the end
    tf_name = tf_name.split("/")  # Convert from TF2.0 '/' separators to PyTorch '.' separators
    # Some weights have a single name without "/" such as final_logits_bias in BART
    if len(tf_name) > 1:
        tf_name = tf_name[1:]  # Remove level zero

    # The shape is optional (it defaults to None), so only materialise it when one was supplied.
    # Calling list(None) would raise a TypeError before any of the rank checks below could run.
    if tf_weight_shape is not None:
        tf_weight_shape = list(tf_weight_shape)
    rank = len(tf_weight_shape) if tf_weight_shape is not None else None

    # When should we transpose the weights
    last_level = tf_name[-1]
    if last_level == "kernel" and rank == 4:
        transpose = TransposeType.CONV2D
    elif last_level == "kernel" and rank == 3:
        transpose = TransposeType.CONV1D
    elif bool(
        last_level in ["kernel", "pointwise_kernel", "depthwise_kernel"]
        or "emb_projs" in tf_name
        or "out_projs" in tf_name
    ):
        transpose = TransposeType.SIMPLE
    else:
        transpose = TransposeType.NO

    # Convert standard TF2.0 names in PyTorch names
    if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma":
        tf_name[-1] = "weight"
    if tf_name[-1] == "beta":
        tf_name[-1] = "bias"

    # The SeparableConv1D TF layer contains two weights that are translated to PyTorch Conv1D here
    if tf_name[-1] == "pointwise_kernel" or tf_name[-1] == "depthwise_kernel":
        tf_name[-1] = tf_name[-1].replace("_kernel", ".weight")

    # Remove prefix if needed
    tf_name = ".".join(tf_name)
    if start_prefix_to_remove:
        # NOTE: this removes the first occurrence of the string wherever it appears in the name,
        # not exclusively a leading prefix.
        tf_name = tf_name.replace(start_prefix_to_remove, "", 1)

    return tf_name, transpose
def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True):
    """
    Apply a transpose to some weight then tries to reshape the weight to the same shape as a given shape, all in a
    framework agnostic way.

    Args:
        transpose: `TransposeType` member selecting which axis permutation (if any) to apply.
        weight: The weight to transform. It is only manipulated through the framework-agnostic helpers imported by
            this module, plus its `.shape` attribute.
        match_shape: Optional target shape. When `None`, the (possibly transposed) weight is returned as is.
        pt_to_tf: Direction of the conversion. Only affects the `CONV2D` permutation; the `CONV1D` and `SIMPLE`
            permutations used here do not depend on it.

    Returns:
        The transposed weight, additionally squeezed / expanded / reshaped towards `match_shape` when one is given.

    Raises:
        AssertionError: Re-raised from the reshape helper, with the current weight shape and the requested shape
            appended to the exception arguments.
    """
    if transpose is TransposeType.CONV2D:
        # Conv2D weight:
        #    PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
        # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
        axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1)
        weight = transpose_func(weight, axes=axes)
    elif transpose is TransposeType.CONV1D:
        # Conv1D weight:
        #    PT: (num_out_channel, num_in_channel, kernel)
        # -> TF: (kernel, num_in_channel, num_out_channel)
        # Reversing the three axes is its own inverse, so both directions share the same permutation.
        weight = transpose_func(weight, axes=(2, 1, 0))
    elif transpose is TransposeType.SIMPLE:
        weight = transpose_func(weight)

    if match_shape is None:
        return weight

    # Bring the rank in line with the target before comparing shapes.
    if len(match_shape) < len(weight.shape):
        weight = squeeze(weight)
    elif len(match_shape) > len(weight.shape):
        weight = expand_dims(weight, axis=0)

    if list(match_shape) != list(weight.shape):
        try:
            weight = reshape(weight, match_shape)
        except AssertionError as e:
            # Report both sides of the mismatch: the shape we had and the shape we were asked to produce.
            e.args += (weight.shape, match_shape)
            raise

    return weight
#####################
# PyTorch => TF 2.0 #
#####################
def load_pytorch_checkpoint_in_tf2_model(
    tf_model,
    pytorch_checkpoint_path,
    tf_inputs=None,
    allow_missing_keys=False,
    output_loading_info=False,
    _prefix=None,
    tf_to_pt_weight_rename=None,
):
    """
    Load pytorch checkpoints in a TF 2.0 model

    Args:
        tf_model: The TF 2.0 model that receives the weights.
        pytorch_checkpoint_path: Either a single checkpoint path (`str` or `os.PathLike`) or an iterable of shard
            paths. All shards are merged into one state dict; on duplicate keys the later shard wins.
        tf_inputs, allow_missing_keys, output_loading_info, _prefix, tf_to_pt_weight_rename: Forwarded unchanged
            to `load_pytorch_weights_in_tf2_model`.

    Returns:
        Whatever `load_pytorch_weights_in_tf2_model` returns for the merged state dict.

    Raises:
        ImportError: If either TensorFlow or PyTorch cannot be imported.
    """
    try:
        import tensorflow as tf  # noqa: F401
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # Treats a single file as a collection of shards with 1 shard.
    # Path-like objects (e.g. pathlib.Path) are single paths too; they are not iterable, so they must be
    # wrapped here rather than handed to the loop below.
    if isinstance(pytorch_checkpoint_path, (str, os.PathLike)):
        pytorch_checkpoint_path = [pytorch_checkpoint_path]

    # Loads all shards into a single state dictionary
    pt_state_dict = {}
    for path in pytorch_checkpoint_path:
        pt_path = os.path.abspath(path)
        logger.info(f"Loading PyTorch weights from {pt_path}")
        # NOTE(security): torch.load unpickles the file content. Only load checkpoints from trusted sources.
        pt_state_dict.update(torch.load(pt_path, map_location="cpu"))

    logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")

    return load_pytorch_weights_in_tf2_model(
        tf_model,
        pt_state_dict,
        tf_inputs=tf_inputs,
        allow_missing_keys=allow_missing_keys,
        output_loading_info=output_loading_info,
        _prefix=_prefix,
        tf_to_pt_weight_rename=tf_to_pt_weight_rename,
    )
def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False):
    """
    Copy the weights of an in-memory PyTorch model into a TF 2.0 model.

    The PyTorch model's `state_dict()` is taken and handed to `load_pytorch_weights_in_tf2_model` together with
    `tf_inputs` and `allow_missing_keys`; that function's return value is returned unchanged.
    """
    return load_pytorch_weights_in_tf2_model(
        tf_model,
        pt_model.state_dict(),
        tf_inputs=tf_inputs,
        allow_missing_keys=allow_missing_keys,
    )
def load_pytorch_weights_in_tf2_model(
    tf_model,
    pt_state_dict,
    tf_inputs=None,
    allow_missing_keys=False,
    output_loading_info=False,
    _prefix=None,
    tf_to_pt_weight_rename=None,
):
    """
    Load a PyTorch state_dict (name -> tensor) in a TF 2.0 model.

    Every value of `pt_state_dict` is converted with its `.numpy()` method, and the resulting mapping is passed to
    `load_pytorch_state_dict_in_tf2_model` along with the remaining arguments. The result of that call is returned.

    Raises:
        ImportError: If either TensorFlow or PyTorch cannot be imported.
    """
    try:
        import tensorflow as tf  # noqa: F401
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # Swap every tensor for its numpy counterpart, keeping keys and their order.
    numpy_state_dict = {}
    for weight_name, tensor in pt_state_dict.items():
        numpy_state_dict[weight_name] = tensor.numpy()

    return load_pytorch_state_dict_in_tf2_model(
        tf_model,
        numpy_state_dict,
        tf_inputs=tf_inputs,
        allow_missing_keys=allow_missing_keys,
        output_loading_info=output_loading_info,
        _prefix=_prefix,
        tf_to_pt_weight_rename=tf_to_pt_weight_rename,
    )
def _pt_state_dict_key_to_lookup_key(key):
    """Map a PyTorch state_dict key to the name used when matching it against converted TF weight names."""
    new_key = None
    # Old-format names. Each replacement starts from the original `key`, so when several of the
    # substrings are present, the last matching rule is the one that takes effect.
    if "gamma" in key:
        new_key = key.replace("gamma", "weight")
    if "beta" in key:
        new_key = key.replace("beta", "bias")
    if "running_var" in key:
        new_key = key.replace("running_var", "moving_variance")
    if "running_mean" in key:
        new_key = key.replace("running_mean", "moving_mean")

    # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
    # "<...>.parametrizations.<param>.original0|1" becomes "<...>.<param>_g|_v"
    key_components = key.split(".")
    name = None
    if key_components[-3::2] == ["parametrizations", "original0"]:
        name = key_components[-2] + "_g"
    elif key_components[-3::2] == ["parametrizations", "original1"]:
        name = key_components[-2] + "_v"
    if name is not None:
        key_components = key_components[:-3] + [name]
        new_key = ".".join(key_components)

    if new_key is None:
        new_key = key
    return new_key


def load_pytorch_state_dict_in_tf2_model(
    tf_model,
    pt_state_dict,
    tf_inputs=None,
    allow_missing_keys=False,
    output_loading_info=False,
    _prefix=None,
    tf_to_pt_weight_rename=None,
    ignore_mismatched_sizes=False,
):
    """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
    safetensors archive created with the safe_open() function.

    Args:
        tf_model: TF 2.0 model receiving the weights. It is called once on `tf_inputs` so its variables exist.
        pt_state_dict: Mapping from PyTorch weight names to arrays, or an object exposing `keys()` and
            `get_tensor(name)`.
        tf_inputs: Inputs used to build the model. Defaults to `tf_model.dummy_inputs`.
        allow_missing_keys: When true, TF weights without a PyTorch counterpart are recorded instead of raising.
        output_loading_info: When true, a dict with missing / unexpected / mismatched keys is returned as well.
        _prefix: Name scope under which the model is built and which is stripped from TF weight names.
        tf_to_pt_weight_rename: Optional callable applied to each converted weight name before lookup.
        ignore_mismatched_sizes: When true, weights whose reshape fails with `tf.errors.InvalidArgumentError` are
            skipped and reported instead of raising.

    Returns:
        `tf_model`, or the tuple `(tf_model, loading_info)` when `output_loading_info` is true.

    Raises:
        AttributeError: If a TF weight has no PyTorch counterpart, `allow_missing_keys` is false and the name is
            not matched by `tf_model._keys_to_ignore_on_load_missing`.
        tf.errors.InvalidArgumentError: If adapting a weight to the TF variable shape fails and
            `ignore_mismatched_sizes` is false.
    """
    import tensorflow as tf
    from keras import backend as K

    if tf_inputs is None:
        tf_inputs = tf_model.dummy_inputs

    if _prefix is None:
        _prefix = ""
    if tf_inputs:
        with tf.name_scope(_prefix):
            tf_model(tf_inputs, training=False)  # Make sure model is built

    # Convert old format to new format if needed from a PyTorch state_dict
    tf_keys_to_pt_keys = {}
    for key in pt_state_dict.keys():
        tf_keys_to_pt_keys[_pt_state_dict_key_to_lookup_key(key)] = key

    # Matt: All TF models store the actual model stem in a MainLayer class, including the base model.
    # In PT, the derived models (with heads) use the base model class as the stem instead,
    # and there is no MainLayer class. This means that TF base classes have one
    # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that.
    start_prefix_to_remove = ""
    if not any(s.startswith(tf_model.base_model_prefix) for s in tf_keys_to_pt_keys.keys()):
        start_prefix_to_remove = tf_model.base_model_prefix + "."

    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
    tf_loaded_numel = 0
    all_pytorch_weights = set(tf_keys_to_pt_keys.keys())
    missing_keys = []
    mismatched_keys = []
    is_safetensor_archive = hasattr(pt_state_dict, "get_tensor")
    for symbolic_weight in symbolic_weights:
        sw_name = symbolic_weight.name
        name, transpose = convert_tf_weight_name_to_pt_weight_name(
            sw_name,
            start_prefix_to_remove=start_prefix_to_remove,
            tf_weight_shape=symbolic_weight.shape,
            name_scope=_prefix,
        )
        if tf_to_pt_weight_rename is not None:
            name = tf_to_pt_weight_rename(name)

        # Find associated numpy array in pytorch model state dict
        if name not in tf_keys_to_pt_keys:
            if allow_missing_keys:
                missing_keys.append(name)
                continue
            elif tf_model._keys_to_ignore_on_load_missing is not None:
                # authorized missing keys don't have to be loaded
                if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
                    continue
            raise AttributeError(f"{name} not found in PyTorch model")

        state_dict_name = tf_keys_to_pt_keys[name]
        if is_safetensor_archive:
            array = pt_state_dict.get_tensor(state_dict_name)
        else:
            array = pt_state_dict[state_dict_name]
        try:
            array = apply_transpose(transpose, array, symbolic_weight.shape)
        except tf.errors.InvalidArgumentError as e:
            if not ignore_mismatched_sizes:
                error_msg = str(e)
                error_msg += (
                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
                )
                # The constructor takes (node_def, op, message); passing only the message would itself
                # fail with a TypeError and hide the hint we are trying to surface.
                raise tf.errors.InvalidArgumentError(e.node_def, e.op, error_msg) from e
            else:
                # `array` is still the untransformed checkpoint array here, since the assignment above failed.
                mismatched_keys.append((name, array.shape, symbolic_weight.shape))
                continue

        tf_loaded_numel += tensor_size(array)

        K.set_value(symbolic_weight, array)
        del array  # Immediately free memory to keep peak usage as low as possible
        all_pytorch_weights.discard(name)

    logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")

    # Whatever was never matched to a TF weight is unexpected from the TF model's point of view.
    unexpected_keys = list(all_pytorch_weights)

    if tf_model._keys_to_ignore_on_load_missing is not None:
        for pat in tf_model._keys_to_ignore_on_load_missing:
            missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
    if tf_model._keys_to_ignore_on_load_unexpected is not None:
        for pat in tf_model._keys_to_ignore_on_load_unexpected:
            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
            f" {tf_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {tf_model.__class__.__name__} from a PyTorch model trained on another task or with another architecture"
            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
            f" NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect"
            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the"
            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
            " down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {tf_model.__class__.__name__} for predictions without further training."
        )

    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        logger.warning(
            f"Some weights of {tf_model.__class__.__name__} were not initialized from the model checkpoint"
            f" and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
            " to use it for predictions and inference."
        )

    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
        }
        return tf_model, loading_info

    return tf_model
#####################
# TF 2.0 => PyTorch #
#####################
def load_tf2_checkpoint_in_pytorch_model(
    pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
):
    """
    Load a TF 2.0 HDF5 checkpoint into a PyTorch model. HDF5 is used because it makes transfer learning easy (see
    https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).

    The TF counterpart of `pt_model` is looked up on the `transformers` package by prefixing the PyTorch class name
    with "TF", instantiated from `pt_model.config`, built, filled from the checkpoint, and finally transferred to
    `pt_model` through `load_tf2_model_in_pytorch_model`, whose return value is returned.

    Raises:
        ImportError: If either TensorFlow or PyTorch cannot be imported.
    """
    try:
        import tensorflow as tf  # noqa: F401
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    import transformers

    from .modeling_tf_utils import load_tf_weights

    logger.info(f"Loading TensorFlow weights from {tf_checkpoint_path}")

    # Instantiate and load the associated TF 2.0 model ("TF" + PyTorch class name)
    tf_class = getattr(transformers, "TF" + pt_model.__class__.__name__)
    tf_model = tf_class(pt_model.config)

    # Fall back to the model's own dummy inputs when the caller gave none.
    build_inputs = tf_model.dummy_inputs if tf_inputs is None else tf_inputs
    if build_inputs is not None:
        tf_model(build_inputs, training=False)  # Make sure model is built

    load_tf_weights(tf_model, tf_checkpoint_path)

    return load_tf2_model_in_pytorch_model(
        pt_model,
        tf_model,
        allow_missing_keys=allow_missing_keys,
        output_loading_info=output_loading_info,
    )
def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False, output_loading_info=False):
    """
    Copy the weights of a TF 2.0 model into a PyTorch model.

    Passes `tf_model.weights` to `load_tf2_weights_in_pytorch_model` and returns its result.
    """
    return load_tf2_weights_in_pytorch_model(
        pt_model,
        tf_model.weights,
        allow_missing_keys=allow_missing_keys,
        output_loading_info=output_loading_info,
    )
def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False, output_loading_info=False):
    """
    Load TF2.0 symbolic weights in a PyTorch model.

    Each weight is turned into a `name -> weight.numpy()` entry and the resulting mapping is forwarded to
    `load_tf2_state_dict_in_pytorch_model`, whose return value is returned.

    Raises:
        ImportError: If either TensorFlow or PyTorch cannot be imported.
    """
    try:
        import tensorflow as tf  # noqa: F401
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # Materialise every variable as a numpy array keyed by the variable's name.
    tf_state_dict = {}
    for variable in tf_weights:
        tf_state_dict[variable.name] = variable.numpy()

    return load_tf2_state_dict_in_pytorch_model(
        pt_model,
        tf_state_dict,
        allow_missing_keys=allow_missing_keys,
        output_loading_info=output_loading_info,
    )
def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False):
    """
    Load a TF 2.0 state dict (variable name -> array) in a PyTorch model.

    Args:
        pt_model: PyTorch model receiving the weights (through `load_state_dict(..., strict=False)`).
        tf_state_dict: Mapping from TF 2.0 variable names to their values. Each value must expose `.shape`.
        allow_missing_keys: When true, PyTorch parameters without a TF counterpart are recorded as missing
            instead of raising.
        output_loading_info: When true, a dict with the missing and unexpected keys is returned as well.

    Returns:
        `pt_model`, or the tuple `(pt_model, loading_info)` when `output_loading_info` is true.

    Raises:
        AttributeError: If a PyTorch parameter has no TF counterpart and `allow_missing_keys` is false.
    """
    import torch

    new_pt_params_dict = {}
    current_pt_params_dict = dict(pt_model.named_parameters())

    # Make sure we are able to load PyTorch base models as well as derived models (with heads)
    # TF models always have a prefix, some of PyTorch models (base ones) don't
    start_prefix_to_remove = ""
    if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict.keys()):
        start_prefix_to_remove = pt_model.base_model_prefix + "."

    # Build a map from potential PyTorch weight names to TF 2.0 Variables
    tf_weights_map = {}
    for name, tf_weight in tf_state_dict.items():
        pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
            name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
        )
        tf_weights_map[pt_name] = (tf_weight, transpose)

    all_tf_weights = set(tf_weights_map.keys())
    loaded_pt_weights_data_ptr = {}
    missing_keys_pt = []
    for pt_weight_name, pt_weight in current_pt_params_dict.items():
        # Handle PyTorch shared weights (not duplicated in TF 2.0): reuse the array already loaded
        # for the parameter that shares the same storage pointer.
        if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
            new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
            continue

        pt_weight_name_to_check = pt_weight_name
        # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
        # "<...>.parametrizations.<param>.original0|1" is looked up as "<...>.<param>_g|_v"
        key_components = pt_weight_name.split(".")
        name = None
        if key_components[-3::2] == ["parametrizations", "original0"]:
            name = key_components[-2] + "_g"
        elif key_components[-3::2] == ["parametrizations", "original1"]:
            name = key_components[-2] + "_v"
        if name is not None:
            key_components = key_components[:-3] + [name]
            pt_weight_name_to_check = ".".join(key_components)

        # Find associated numpy array in pytorch model state dict
        if pt_weight_name_to_check not in tf_weights_map:
            if allow_missing_keys:
                missing_keys_pt.append(pt_weight_name)
                continue

            raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model")

        array, transpose = tf_weights_map[pt_weight_name_to_check]

        array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)

        # Normalise whatever came back into a torch tensor.
        if numpy.isscalar(array):
            array = numpy.array(array)
        if not is_torch_tensor(array) and not is_numpy_array(array):
            array = array.numpy()
        if is_numpy_array(array):
            # Convert to torch tensor
            array = torch.from_numpy(array)

        new_pt_params_dict[pt_weight_name] = array
        loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array
        # `all_tf_weights` is keyed by the TF-derived names, i.e. the name that was actually looked up.
        # Discarding the raw PyTorch name would leave loaded weight-norm weights reported as unused.
        all_tf_weights.discard(pt_weight_name_to_check)

    missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
    missing_keys += missing_keys_pt

    # Some models may have keys that are not in the state by design, removing them before needlessly warning
    # the user.
    if pt_model._keys_to_ignore_on_load_missing is not None:
        for pat in pt_model._keys_to_ignore_on_load_missing:
            missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

    if pt_model._keys_to_ignore_on_load_unexpected is not None:
        for pat in pt_model._keys_to_ignore_on_load_unexpected:
            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the TF 2.0 model were not used when initializing the PyTorch model"
            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {pt_model.__class__.__name__} from a TF 2.0 model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n- This IS"
            f" NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " TFBertForSequenceClassification model)."
        )
    else:
        logger.warning(f"All TF 2.0 model weights were used when initializing {pt_model.__class__.__name__}.\n")
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {pt_model.__class__.__name__} were not initialized from the TF 2.0 model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {pt_model.__class__.__name__} for predictions without further training."
        )

    logger.info(f"Weights or buffers not loaded from TF 2.0 model: {all_tf_weights}")

    if output_loading_info:
        loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
        return pt_model, loading_info

    return pt_model