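# Sanity-check script: load the Flax GPT-2 checkpoint from the parent
# directory, convert it to PyTorch, and compare the logits of both
# backends on the same dummy input.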
import torch
import numpy as np

import jax
import jax.numpy as jnp

from transformers import AutoTokenizer
from transformers import FlaxGPT2LMHeadModel
from transformers import GPT2LMHeadModel

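# GPT-2's tokenizer defines no pad token, so reuse the EOS token for padding.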
tokenizer = AutoTokenizer.from_pretrained("../")
tokenizer.pad_token = tokenizer.eos_token

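# Load the Flax version of the model.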
model_fx = FlaxGPT2LMHeadModel.from_pretrained("../")

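# Optional (kept commented out): upcast any bfloat16 params to float32
# before re-saving the Flax checkpoint.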
# def to_f32(t):
#     return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)

# model_fx.params = to_f32(model_fx.params)
# model_fx.save_pretrained("./fx")

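# Convert the Flax weights to PyTorch and save a PyTorch copy.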
model_pt = GPT2LMHeadModel.from_pretrained("../", from_flax=True)
model_pt.save_pretrained("./pt")

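# Dummy input: a batch of 2 sequences of 128 zero-valued token ids.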
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
input_ids_pt = torch.tensor(input_ids)

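# Forward pass through both models; the printed logits should match closely.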
logits_pt = model_pt(input_ids_pt).logits
print(logits_pt)
logits_fx = model_fx(input_ids).logits
print(logits_fx)
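
# Added sketch (not part of the original script): quantify the agreement
# between the two backends instead of eyeballing the printed tensors.
# Assumes both logits tensors fit in host memory; the 1e-3 tolerance is a
# guess and may need loosening for lower-precision weights.
max_diff = np.max(np.abs(np.asarray(logits_fx) - logits_pt.detach().numpy()))
print(f"max abs difference between Flax and PyTorch logits: {max_diff:.2e}")
assert max_diff < 1e-3, "PyTorch and Flax logits diverge"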