import torch
from safetensors import safe_open

# Substring renames applied, in order, to every checkpoint key.
_COMMON_RENAMES = (
    ('visual.', 'vision_model.'),
    ('text.', 'text_model.'),
    ('.trunk.', '.encoder.'),
    ('.blocks.', '.layers.'),
    ('.transformer.resblocks.', '.encoder.layers.'),
)

# Vision-tower renames applied (in order) before the fused-qkv check.
_VISION_RENAMES = (
    ('.self_attn.out_proj.', '.attn.proj.'),
    ('.norm', '.layer_norm'),
    # mappings extracted from timm code:
    # ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2')
    ('.proj.c_fc', '.mlp.fc1'),
    ('.proj.c_proj', '.mlp.fc2'),
    ('.attn.proj', '.self_attn.out_proj'),
)

# Vision renames applied after the qkv check: trunk-level tensors are moved
# into the embedding / pre-layernorm / post-layernorm slots of the target layout.
_VISION_TOP_LEVEL_RENAMES = (
    ('vision_model.encoder.patch_embed.proj.weight',
     'vision_model.embeddings.patch_embedding.weight'),
    ('vision_model.encoder.pos_embed',
     'vision_model.embeddings.position_embedding.weight'),
    # NOTE(review): the patch-embedding bias is parked in the pre-layernorm
    # bias slot, presumably because the target patch embedding has no bias.
    ('vision_model.encoder.patch_embed.proj.bias',
     'vision_model.pre_layrnorm.bias'),
    ('vision_model.encoder.layer_norm.bias',
     'vision_model.post_layernorm.bias'),
    ('vision_model.encoder.layer_norm.weight',
     'vision_model.post_layernorm.weight'),
)

# Text-tower renames applied (in order) before the fused-qkv check.
# e.g. text_model.encoder.layers.0.ln_1.bias -> ...layers.0.layer_norm1.bias
#      ...attn.in_proj_weight -> ...self_attn.qkv.weight (split below)
_TEXT_RENAMES = (
    ('.ln_2.', '.layer_norm2.'),
    ('.ln_1.', '.layer_norm1.'),
    ('.mlp.c_fc', '.mlp.fc1'),
    ('.mlp.c_proj', '.mlp.fc2'),
    ('.attn.in_proj_', '.self_attn.qkv.'),
    ('.attn.out_proj', '.self_attn.out_proj'),
)

# Text renames applied after the qkv check (embeddings, final norm, projection).
_TEXT_TOP_LEVEL_RENAMES = (
    ('text_model.positional_embedding',
     'text_model.embeddings.position_embedding.weight'),
    ('text_model.token_embedding.weight',
     'text_model.embeddings.token_embedding.weight'),
    ('text_model.ln_final.bias', 'text_model.final_layer_norm.bias'),
    ('text_model.ln_final.weight', 'text_model.final_layer_norm.weight'),
    ('text_model.text_projection.weight', 'text_projection.weight'),
)


def _apply_renames(key, renames):
    """Return *key* after applying each ``(old, new)`` substring replacement in order."""
    for old, new in renames:
        key = key.replace(old, new)
    return key


def _split_qkv(qkv, key, fused_name):
    """Split a fused qkv tensor along dim 0 into q/k/v entries.

    Args:
        qkv: fused tensor whose dim 0 stacks q, k and v.
        key: already-renamed key of the fused tensor.
        fused_name: substring of *key* naming the fused projection; it is
            replaced by ``.self_attn.{q,k,v}_proj`` to build the new keys.

    Returns:
        Dict mapping the three new keys (in q, k, v order) to their chunks.

    Raises:
        ValueError: if dim 0 of *qkv* cannot be split into three equal parts.
    """
    # torch.chunk silently yields uneven (or fewer than 3) chunks otherwise.
    if qkv.dim() == 0 or qkv.shape[0] % 3 != 0:
        raise ValueError(
            f"cannot split {key!r} with shape {tuple(qkv.shape)} into equal q/k/v parts"
        )
    q_part, k_part, v_part = torch.chunk(qkv, 3, dim=0)
    # The chunks are views of one storage; clone so each entry owns its memory.
    return {
        key.replace(fused_name, '.self_attn.q_proj'): q_part.clone().detach(),
        key.replace(fused_name, '.self_attn.k_proj'): k_part.clone().detach(),
        key.replace(fused_name, '.self_attn.v_proj'): v_part.clone().detach(),
    }


def _pad_position_embedding(pos_embed):
    """Drop the leading dim of *pos_embed* and append a copy of its last row.

    Presumably the target layout reserves one extra position (e.g. for a
    class token) that this checkpoint lacks -- TODO confirm. Unlike a freshly
    allocated ``torch.zeros`` buffer, ``torch.cat`` keeps the source dtype
    and device.
    """
    table = pos_embed[0]  # assumes a leading batch-like dim of size 1 -- confirm
    return torch.cat([table, table[-1:]], dim=0)


def transform(open_clip_safe_tensor_path, device=0):
    """Load an open_clip safetensors checkpoint and rename it to a CLIP-style layout.

    Every key gets the common prefix renames. Keys containing ``'vision'``
    or ``'text'`` then get that tower's rules. Fused qkv tensors are split
    into separate q/k/v projections. The vision position embedding is padded
    by one row. Finally, a unit ``vision_model.pre_layrnorm.weight`` and a
    random ``vision_model.embeddings.class_embedding`` are synthesized.

    Args:
        open_clip_safe_tensor_path: path to the ``.safetensors`` checkpoint.
        device: device passed to ``safe_open`` (default ``0``, the original
            hard-coded value).

    Returns:
        Dict mapping new tensor names to tensors.

    Raises:
        KeyError: if the checkpoint yields no ``vision_model.pre_layrnorm.bias``
            (i.e. no visual patch-embedding bias).
        ValueError: if a fused qkv tensor cannot be split into three equal parts.
    """
    tensors = {}
    with safe_open(open_clip_safe_tensor_path, framework="pt", device=device) as f:
        for src_key in f.keys():
            new_key = _apply_renames(src_key, _COMMON_RENAMES)
            if 'vision' in new_key:
                renames, fused_name, top_level = (
                    _VISION_RENAMES, '.attn.qkv', _VISION_TOP_LEVEL_RENAMES)
            elif 'text' in new_key:
                renames, fused_name, top_level = (
                    _TEXT_RENAMES, '.self_attn.qkv', _TEXT_TOP_LEVEL_RENAMES)
            else:
                renames = fused_name = top_level = None

            is_fused_qkv = False
            if renames is not None:
                new_key = _apply_renames(new_key, renames)
                if 'qkv' in new_key:
                    tensors.update(_split_qkv(f.get_tensor(src_key), new_key, fused_name))
                    is_fused_qkv = True
                new_key = _apply_renames(new_key, top_level)

            # Debug aid kept from the original: surface any image-projector keys.
            if 'vision' in new_key and 'img_projector' in new_key:
                print(new_key)
            if is_fused_qkv:
                continue  # already stored as separate q/k/v entries

            tensor = f.get_tensor(src_key)
            if 'vision_model.embeddings.position_embedding' in new_key:
                tensor = _pad_position_embedding(tensor)
            tensors[new_key] = tensor

    bias = tensors.get('vision_model.pre_layrnorm.bias')
    if bias is None:
        raise KeyError(
            "'vision_model.pre_layrnorm.bias' not produced; expected it from the "
            "checkpoint's 'visual.trunk.patch_embed.proj.bias'"
        )
    # siglip doesn't seem to have any pre norm layer so we have to make it identity for now
    # NOTE(review): if pre_layrnorm is a LayerNorm, a unit weight does not make it an
    # identity (it still normalizes, and its bias holds the patch-embed bias) -- confirm.
    tensors['vision_model.pre_layrnorm.weight'] = torch.ones_like(bias)
    # Original note: this wasn't used. Cast to match the rest of the checkpoint.
    tensors['vision_model.embeddings.class_embedding'] = torch.normal(
        mean=0.0, std=0.02, size=bias.shape
    ).to(dtype=bias.dtype, device=bias.device)
    return tensors