import torch
from safetensors import safe_open


def transform(open_clip_safe_tensor_path):
    """Convert an OpenCLIP safetensors checkpoint into a Hugging Face CLIP-style state dict."""
    tensors = {}
    with safe_open(open_clip_safe_tensor_path, framework="pt", device=0) as f:
        metadata = f.metadata()  # checkpoint metadata; not used below
        for k in f.keys():
            ignore_tensor = False

            # Rewrite the OpenCLIP tower prefixes into the HF CLIP module names.
            first_prefix = k.replace('visual.', 'vision_model.').replace('text.', 'text_model.')
            new_key = first_prefix.replace('.trunk.', '.encoder.')
            new_key = new_key.replace('.blocks.', '.layers.')
            new_key = new_key.replace('.transformer.resblocks.', '.encoder.layers.')

            if 'vision' in new_key:
                # Rename the vision block sub-modules.
                new_key = new_key.replace('.self_attn.out_proj.', '.attn.proj.')
                new_key = new_key.replace('.norm', '.layer_norm')
                new_key = new_key.replace('.proj.c_fc', '.mlp.fc1')
                new_key = new_key.replace('.proj.c_proj', '.mlp.fc2')
                new_key = new_key.replace('.attn.proj', '.self_attn.out_proj')

                if 'qkv' in new_key:
                    # Split the fused qkv weight/bias into separate q/k/v projections.
                    qkv_weight = f.get_tensor(k)
                    q_w, k_w, v_w = torch.chunk(qkv_weight, 3, dim=0)
                    tensors[new_key.replace('.attn.qkv', '.self_attn.q_proj')] = q_w.clone().detach()
                    tensors[new_key.replace('.attn.qkv', '.self_attn.k_proj')] = k_w.clone().detach()
                    tensors[new_key.replace('.attn.qkv', '.self_attn.v_proj')] = v_w.clone().detach()
                    ignore_tensor = True

                # One-off renames for the vision embeddings and final layer norm.
                replacement_keys = [
                    ('vision_model.encoder.patch_embed.proj.weight', 'vision_model.embeddings.patch_embedding.weight'),
                    ('vision_model.encoder.pos_embed', 'vision_model.embeddings.position_embedding.weight'),
                    ('vision_model.encoder.patch_embed.proj.bias', 'vision_model.pre_layrnorm.bias'),
                    ('vision_model.encoder.layer_norm.bias', 'vision_model.post_layernorm.bias'),
                    ('vision_model.encoder.layer_norm.weight', 'vision_model.post_layernorm.weight'),
                ]
                for old_, new_ in replacement_keys:
                    if old_ in new_key:
                        new_key = new_key.replace(old_, new_)
            elif 'text' in new_key:
                # Rename the text transformer block sub-modules.
                new_key = new_key.replace('.ln_2.', '.layer_norm2.')
                new_key = new_key.replace('.ln_1.', '.layer_norm1.')
                new_key = new_key.replace('.mlp.c_fc', '.mlp.fc1')
                new_key = new_key.replace('.mlp.c_proj', '.mlp.fc2')
                new_key = new_key.replace('.attn.in_proj_', '.self_attn.qkv.')
                new_key = new_key.replace('.attn.out_proj', '.self_attn.out_proj')

                if 'qkv' in new_key:
                    # Split the fused in_proj weight/bias into separate q/k/v projections.
                    qkv_weight = f.get_tensor(k)
                    q_w, k_w, v_w = torch.chunk(qkv_weight, 3, dim=0)
                    tensors[new_key.replace('.self_attn.qkv', '.self_attn.q_proj')] = q_w.clone().detach()
                    tensors[new_key.replace('.self_attn.qkv', '.self_attn.k_proj')] = k_w.clone().detach()
                    tensors[new_key.replace('.self_attn.qkv', '.self_attn.v_proj')] = v_w.clone().detach()
                    ignore_tensor = True

                # One-off renames for the text embeddings, final layer norm and projection.
                replacement_keys = [
                    ('text_model.positional_embedding', 'text_model.embeddings.position_embedding.weight'),
                    ('text_model.token_embedding.weight', 'text_model.embeddings.token_embedding.weight'),
                    ('text_model.ln_final.bias', 'text_model.final_layer_norm.bias'),
                    ('text_model.ln_final.weight', 'text_model.final_layer_norm.weight'),
                    ('text_model.text_projection.weight', 'text_projection.weight'),
                ]
                for old_, new_ in replacement_keys:
                    if old_ in new_key:
                        new_key = new_key.replace(old_, new_)

            # Surface any image-projector keys for inspection.
            if 'vision' in new_key and 'img_projector' in new_key:
                print(new_key)

            if ignore_tensor:
                continue

            tensors[new_key] = f.get_tensor(k)

            if 'vision_model.embeddings.position_embedding' in new_key:
                # The source pos_embed is (1, num_patches, dim); drop the batch dim and
                # append one extra row (a copy of the last row) so its length matches
                # HF CLIP's num_patches + 1 positions.
                tensor = tensors[new_key][0]
                new_tensor = torch.zeros((tensor.shape[0] + 1, tensor.shape[1]),
                                         dtype=tensor.dtype, device=tensor.device)
                new_tensor[:tensor.shape[0], :] = tensor
                new_tensor[-1, :] = tensor[-1, :]
                tensors[new_key] = new_tensor

    # 'pre_layrnorm' is the actual attribute name in transformers' CLIP vision model;
    # only its bias was filled above (from the patch-embedding bias), so initialise
    # its weight to ones.
    pre_ln_bias = tensors['vision_model.pre_layrnorm.bias']
    tensors['vision_model.pre_layrnorm.weight'] = torch.ones(
        pre_ln_bias.shape, dtype=pre_ln_bias.dtype, device=pre_ln_bias.device)

    # The source checkpoint provides no class embedding, so initialise one randomly.
    tensors['vision_model.embeddings.class_embedding'] = torch.normal(
        mean=0.0, std=0.02, size=pre_ln_bias.shape)

    return tensors
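

# Usage sketch (not part of the conversion itself): one way the returned dict might be
# persisted with safetensors and sanity-checked against a transformers CLIP model.
# Assumptions: the file paths below are placeholders, and loading into CLIPModel requires
# a config whose sizes match the source checkpoint; keys not produced above (e.g.
# logit_scale, visual_projection) keep their initial values via strict=False.
if __name__ == "__main__":
    from safetensors.torch import save_file

    converted = transform("open_clip_model.safetensors")  # placeholder input path
    # safetensors expects contiguous tensors; move everything to CPU before writing.
    converted = {name: t.contiguous().cpu() for name, t in converted.items()}
    save_file(converted, "hf_clip_model.safetensors")  # placeholder output path

    # Optional check: load the converted weights into a matching HF CLIP model.
    # from transformers import CLIPModel
    # model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    # missing, unexpected = model.load_state_dict(converted, strict=False)
    # print(missing, unexpected)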