import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, FlaxAutoModelForMaskedLM
from wechsel import WECHSEL, load_embeddings
from wechsel import WECHSEL, load_embeddings
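
# Load the English source model and its tokenizer.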
source_tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModelForMaskedLM.from_pretrained("roberta-large")
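
# Load the Finnish target tokenizer from the current directory.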
target_tokenizer = AutoTokenizer.from_pretrained("./")
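
# Set up WECHSEL with fastText embeddings for the source (English) and
# target (Finnish) languages, plus an English-Finnish bilingual dictionary.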
wechsel = WECHSEL(
    load_embeddings("en"),
    load_embeddings("fi"),
    bilingual_dictionary="finnish",
)
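
# Compute embeddings for the target tokenizer's vocabulary by aligning the
# source embedding matrix across languages.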
target_embeddings, info = wechsel.apply(
    source_tokenizer,
    target_tokenizer,
    model.get_input_embeddings().weight.detach().numpy(),
)
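
# Replace the model's input embeddings with the transferred matrix and save.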
model.get_input_embeddings().weight.data = torch.from_numpy(target_embeddings).to(torch.float32)
model.save_pretrained("./")
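
# Optionally convert the PyTorch checkpoint to Flax: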
# flax_model = FlaxAutoModelForMaskedLM.from_pretrained("./", from_pt=True)
# flax_model.save_pretrained("./")
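
# A quick sanity check after transfer (a sketch; assumes the directory "./" now
# contains both the re-initialized model and the Finnish tokenizer). Note that
# WECHSEL only initializes the embeddings; the model still needs to be trained
# on Finnish text before its predictions are meaningful.
#
# from transformers import pipeline
#
# fill_mask = pipeline("fill-mask", model="./", tokenizer="./")
# print(fill_mask("Helsinki on Suomen <mask>."))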