# Uploaded by: ikala-ray
# Commit message: "Upload transform.py"
# Commit: 3055d59
import torch
from safetensors import safe_open
# Renames applied to every key (vision or text) before branch-specific rules,
# in this order.
_COMMON_RENAMES = (
    ('visual.', 'vision_model.'),
    ('text.', 'text_model.'),
    ('.trunk.', '.encoder.'),
    ('.blocks.', '.layers.'),
    ('.transformer.resblocks.', '.encoder.layers.'),
)

# Vision-tower substring renames. Order matters: each rule sees the output of
# the previous ones (e.g. '.self_attn.out_proj.' is rewritten to '.attn.proj.'
# by the first rule and back to '.self_attn.out_proj' by the last one).
_VISION_RENAMES = (
    ('.self_attn.out_proj.', '.attn.proj.'),
    ('.norm', '.layer_norm'),
    # mappings extracted from timm code:
    # ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2')
    ('.proj.c_fc', '.mlp.fc1'),
    ('.proj.c_proj', '.mlp.fc2'),
    ('.attn.proj', '.self_attn.out_proj'),
)

# Whole-name vision renames, applied after the fused-qkv check.
# NOTE(review): the patch-embedding bias is stored as the pre-norm bias; the
# target patch embedding presumably has no bias of its own — confirm.
_VISION_FINAL_RENAMES = (
    ('vision_model.encoder.patch_embed.proj.weight', 'vision_model.embeddings.patch_embedding.weight'),
    ('vision_model.encoder.pos_embed', 'vision_model.embeddings.position_embedding.weight'),
    ('vision_model.encoder.patch_embed.proj.bias', 'vision_model.pre_layrnorm.bias'),
    ('vision_model.encoder.layer_norm.bias', 'vision_model.post_layernorm.bias'),
    ('vision_model.encoder.layer_norm.weight', 'vision_model.post_layernorm.weight'),
)

# Text-tower substring renames, e.g.
#   text_model.encoder.layers.0.ln_1.bias     -> ...layers.0.layer_norm1.bias
#   text_model.encoder.layers.1.mlp.c_fc.bias -> ...layers.1.mlp.fc1.bias
#   ...attn.in_proj_weight                    -> ...self_attn.qkv.weight
_TEXT_RENAMES = (
    ('.ln_2.', '.layer_norm2.'),
    ('.ln_1.', '.layer_norm1.'),
    ('.mlp.c_fc', '.mlp.fc1'),
    ('.mlp.c_proj', '.mlp.fc2'),
    ('.attn.in_proj_', '.self_attn.qkv.'),
    ('.attn.out_proj', '.self_attn.out_proj'),
)

# Whole-name text renames, applied after the fused-qkv check.
_TEXT_FINAL_RENAMES = (
    ('text_model.positional_embedding', 'text_model.embeddings.position_embedding.weight'),
    ('text_model.token_embedding.weight', 'text_model.embeddings.token_embedding.weight'),
    ('text_model.ln_final.bias', 'text_model.final_layer_norm.bias'),
    ('text_model.ln_final.weight', 'text_model.final_layer_norm.weight'),
    ('text_model.text_projection.weight', 'text_projection.weight'),
)


def _apply_renames(key, renames):
    """Return ``key`` with each ``(old, new)`` substring replacement applied in order."""
    for old, new in renames:
        key = key.replace(old, new)
    return key


def _split_qkv(tensors, fused, fused_key, fused_marker):
    """Split a fused qkv tensor into separate q/k/v entries of ``tensors``.

    ``fused`` is chunked into three equal parts along dim 0, taken as q, k, v
    in that order. Each part is stored (as a detached copy) under ``fused_key``
    with ``fused_marker`` replaced by ``.self_attn.q_proj`` / ``k_proj`` /
    ``v_proj``.

    Raises:
        ValueError: if dim 0 of ``fused`` is not divisible by 3 (torch.chunk
            would otherwise silently return unequal parts).
    """
    if fused.shape[0] % 3:
        raise ValueError(
            f"cannot split {fused_key!r}: dim 0 ({fused.shape[0]}) is not divisible by 3")
    # Distinct names so nothing shadows the caller's loop variables.
    q_part, k_part, v_part = torch.chunk(fused, 3, dim=0)
    for proj, part in (('q_proj', q_part), ('k_proj', k_part), ('v_proj', v_part)):
        tensors[fused_key.replace(fused_marker, '.self_attn.' + proj)] = part.clone().detach()


def _append_position_row(pos_embed):
    """Return the position table with its last row duplicated at the end.

    A 3-D input is assumed to carry a leading batch dim and its first entry is
    used. The result keeps the input's dtype and device.

    NOTE(review): the extra row is appended at the END. If the consuming model
    reserves position 0 for a class token (as HF CLIP-style embeddings do),
    patch i will read row i + 1, i.e. patch positions are shifted by one;
    prepending may have been intended — confirm against the consumer.
    """
    if pos_embed.dim() == 3:
        pos_embed = pos_embed[0]
    elif pos_embed.dim() != 2:
        raise ValueError(
            f"expected a 2-D or 3-D position embedding, got shape {tuple(pos_embed.shape)}")
    return torch.cat([pos_embed, pos_embed[-1:]], dim=0)


def transform(open_clip_safe_tensor_path, device=0):
    """Load an open_clip safetensors checkpoint and remap it to CLIP-style keys.

    Every key is renamed (``visual.`` -> ``vision_model.``, ``text.`` ->
    ``text_model.``, ``.trunk.``/``.blocks.``/``.transformer.resblocks.`` ->
    encoder layers, plus the per-tower tables above). Additionally:

    * fused qkv tensors are split into ``self_attn.{q,k,v}_proj`` entries;
    * the vision position embedding gains one duplicated trailing row;
    * ``vision_model.pre_layrnorm.weight`` (all ones) and a randomly
      initialised ``vision_model.embeddings.class_embedding`` are synthesised,
      both matching the dtype/device of ``vision_model.pre_layrnorm.bias``.

    Keys that match neither ``'vision'`` nor ``'text'`` are kept unchanged.

    Args:
        open_clip_safe_tensor_path: path of the ``.safetensors`` file to read.
        device: forwarded to ``safe_open``; tensors are loaded there. Defaults
            to ``0``, which was previously hard-coded.

    Returns:
        dict mapping the new key names to tensors.

    Raises:
        KeyError: if the checkpoint yields no ``vision_model.pre_layrnorm.bias``
            (i.e. it has no ``visual.trunk.patch_embed.proj.bias``).
        ValueError: on a fused qkv tensor or position embedding whose shape
            cannot be processed.
    """
    tensors = {}
    with safe_open(open_clip_safe_tensor_path, framework="pt", device=device) as f:
        for key in f.keys():
            new_key = _apply_renames(key, _COMMON_RENAMES)
            is_fused_qkv = False
            if 'vision' in new_key:
                new_key = _apply_renames(new_key, _VISION_RENAMES)
                if 'qkv' in new_key:
                    _split_qkv(tensors, f.get_tensor(key), new_key, '.attn.qkv')
                    is_fused_qkv = True
                new_key = _apply_renames(new_key, _VISION_FINAL_RENAMES)
            elif 'text' in new_key:
                new_key = _apply_renames(new_key, _TEXT_RENAMES)
                if 'qkv' in new_key:
                    _split_qkv(tensors, f.get_tensor(key), new_key, '.self_attn.qkv')
                    is_fused_qkv = True
                new_key = _apply_renames(new_key, _TEXT_FINAL_RENAMES)

            # Diagnostic output kept from the original script.
            if 'vision' in new_key and 'img_projector' in new_key:
                print(new_key)

            if is_fused_qkv:
                # Already stored as separate q/k/v entries above.
                continue

            tensor = f.get_tensor(key)
            if 'vision_model.embeddings.position_embedding' in new_key:
                tensor = _append_position_row(tensor)
            tensors[new_key] = tensor

    try:
        pre_norm_bias = tensors['vision_model.pre_layrnorm.bias']
    except KeyError:
        raise KeyError(
            "'vision_model.pre_layrnorm.bias' was not produced; expected a "
            "'visual.trunk.patch_embed.proj.bias' entry in the checkpoint") from None

    # SigLIP doesn't seem to have a pre-norm layer, so it gets unit weight.
    # NOTE(review): if pre_layrnorm is an nn.LayerNorm this is NOT an identity —
    # it still normalises its input before applying the bias; confirm intent.
    tensors['vision_model.pre_layrnorm.weight'] = torch.ones_like(pre_norm_bias)
    # Per the original author this embedding wasn't used; it is filled with
    # N(0, 0.02) noise only so the key exists.
    tensors['vision_model.embeddings.class_embedding'] = torch.empty_like(
        pre_norm_bias).normal_(mean=0.0, std=0.02)
    return tensors