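"""Precompute OWLv2 text embeddings for a set of labels.

Reads id_to_str.json (a mapping from ids to label strings), encodes each
label in batches with the OWLv2 text tower, and saves the pooled embeddings
to embeds.pt as a dict keyed by id.
"""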
import json
import math

import torch
import tqdm
from transformers import AutoTokenizer, Owlv2Processor, Owlv2TextModel

# Plain dict of id -> embedding tensor; these are frozen features, not
# trainable weights, so nn.ParameterDict is unnecessary here.
embed_dict = {}
bsz = 8  # batch size

with open("id_to_str.json") as f:
    data = json.load(f)

keys = list(data.keys())
# Ceil division so a trailing partial batch is not silently dropped.
bar = tqdm.tqdm(range(math.ceil(len(keys) / bsz)))

proc = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
tokenizer = AutoTokenizer.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2TextModel.from_pretrained("google/owlv2-base-patch16-ensemble")
model.eval()  # inference only: disables dropout and similar training behavior

for i in bar:
    batch_keys = keys[i * bsz : (i + 1) * bsz]
    batch = [data[key].replace("_", " ") for key in batch_keys]

    # OWLv2's text encoder accepts at most 16 token positions; flag any
    # label that exceeds this so it can be shortened.
    for ids in tokenizer(batch)["input_ids"]:
        if len(ids) > 16:
            print(f"label longer than 16 tokens: {tokenizer.decode(ids)}")

    inputs = proc(text=batch, return_tensors="pt")
    with torch.no_grad():  # inference only, no autograd graph needed
        output = model(**inputs)
    for k, key in enumerate(batch_keys):
        embed_dict[key] = output.pooler_output[k, :]

torch.save(embed_dict, "embeds.pt")
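
# Sketch of how the saved file can be consumed later (assumes the file name
# above; `some_id` is a hypothetical key from id_to_str.json):
#
#   embeds = torch.load("embeds.pt")  # dict: id -> pooled text embedding
#   query = embeds[some_id]           # shape (512,) for the base model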