File size: 7,532 Bytes
a8edea3 785a54b a8edea3 fb9114d a8edea3 785a54b a8edea3 785a54b caff12e 73ae988 caff12e cdb0880 caff12e cdb0880 caff12e a8edea3 785a54b a8edea3 785a54b dfb402a 785a54b caff12e a8edea3 785a54b a8edea3 785a54b caff12e 2959d62 a8edea3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import os
import rwkv_rs
import numpy as np
import huggingface_hub
import tokenizers
import gradio as gr
# Use a checkpoint sitting next to the script when there is one; otherwise
# fetch the published weights from the Hugging Face Hub.
model_path = (
    "./rnn.safetensors"
    if os.path.exists("./rnn.safetensors")
    else huggingface_hub.hf_hub_download(
        repo_id="mrsteyk/RWKV-LM-safetensors",
        filename="RWKV-4-Pile-7B-Instruct-test1-20230124.rnn.safetensors",
    )
)
assert model_path is not None

# Single shared model and tokenizer; per-request state is created by the
# handler functions themselves.
model = rwkv_rs.Rwkv(model_path)
tokenizer = tokenizers.Tokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# Button updates appended to every streamed result, in the order
# [action button, stop button].
# GT: a generator is running -> hide the action button, show "Stop".
GT = [gr.Button.update(visible=shown) for shown in (False, True)]
# GF: the generator has finished -> show the action button, hide "Stop".
GF = [gr.Button.update(visible=shown) for shown in (True, False)]
def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
    """Stream a greedy completion of ``inpt``.

    The prompt is fed to the model token by token (the decoded prefix is
    yielded as progress), then up to ``max_tokens`` tokens are generated by
    taking the argmax of the penalised logits. Generation stops early when
    token id 0 is selected.

    Args:
        inpt: Prompt text.
        max_tokens: Maximum number of tokens to generate (converted with
            ``int``).
        min_tokens: While fewer than this many tokens have been generated,
            the logit of token id 0 is forced to -100 so generation cannot
            stop.
        alpha_f: Penalty subtracted once per previous occurrence of a
            generated token.
        alpha_p: Flat penalty subtracted from every token generated at least
            once.

    Yields:
        ``(text, error_update)`` tuples: ``text`` is the output text so far
        (or ``None``), ``error_update`` is ``None`` or a ``gr.Text.update``
        for the error box.
    """
    try:
        state = rwkv_rs.State(model)
        text = inpt
        tokens = tokenizer.encode(inpt).ids
        if not tokens:
            # Without at least one prompt token there is nothing to feed the
            # model; report it explicitly instead of an opaque IndexError.
            yield ("Error...", gr.Text.update(value="Input is empty, nothing to complete!", visible=True))
            return
        # Sparse {token id: times generated}. A token that was never
        # generated has a zero penalty, so only seen tokens need adjusting.
        counts = {}
        yield (None, gr.Text.update(visible=False))
        for pos, prompt_token in enumerate(tokens[:-1]):
            model.forward_token_preproc(prompt_token, state)
            yield (tokenizer.decode(tokens[:pos + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)
        max_tokens = int(max_tokens)
        for step in range(max_tokens):
            if step < min_tokens:
                logits[0] = -100
            # A separate loop variable is used here on purpose: reusing the
            # step index would break the last-step check below.
            for seen_token, count in counts.items():
                logits[seen_token] -= (count * alpha_f) + alpha_p
            token = int(np.argmax(logits))
            counts[token] = counts.get(token, 0) + 1
            if token == 0:
                break
            tokens += [token]
            text = tokenizer.decode(tokens)
            yield (text, None)
            if step == max_tokens - 1:
                # Last token produced: skip the now-useless forward pass.
                break
            logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))
def insert_fn(inpt: str, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert):
    """Stream text generated at the ``<|INSERT|>`` marker of ``inpt``.

    The text before the marker is used as the prompt. Tokens are generated
    greedily until the most recently generated tokens equal the first
    ``num_tokens_insert`` tokens of the text after the marker; at that point
    the remaining tokens of that text are appended and generation stops. If
    no match happens within ``max_tokens`` tokens (or token id 0 is
    selected), the trailing text is not appended.

    Args:
        inpt: Text containing exactly one ``<|INSERT|>`` marker.
        max_tokens: Maximum number of tokens to generate (converted with
            ``int``).
        min_tokens: While fewer than this many tokens have been generated,
            the logit of token id 0 is forced to -100.
        alpha_f: Penalty subtracted once per previous occurrence of a
            generated token.
        alpha_p: Flat penalty subtracted from every token generated at least
            once.
        num_tokens_insert: How many leading tokens of the trailing text must
            be reproduced for the insertion to count as complete.

    Yields:
        ``(text, error_update)`` tuples: ``text`` is the output text so far
        (or ``None``), ``error_update`` is ``None`` or a ``gr.Text.update``
        for the error box.
    """
    try:
        if inpt.count("<|INSERT|>") != 1:
            yield ("Error...", gr.Text.update(value="Exactly one replace is allowed!", visible=True))
            return
        state = rwkv_rs.State(model)
        text, end = inpt.split("<|INSERT|>")
        tokens = tokenizer.encode(text).ids
        if not tokens:
            # Nothing precedes the marker, so there is no prompt token to
            # feed the model; report it instead of an opaque IndexError.
            yield ("Error...", gr.Text.update(value="Text before <|INSERT|> must not be empty!", visible=True))
            return
        tokens_end = tokenizer.encode(end).ids
        # Slice bounds must be integers; the UI value may arrive as a float.
        num_tokens_insert = int(num_tokens_insert)
        tokens_i = tokens_end[:num_tokens_insert]
        # Sliding window over the most recently generated tokens.
        ins = [0] * len(tokens_i)
        # Sparse {token id: times generated}; unseen tokens have no penalty.
        counts = {}
        yield (None, gr.Text.update(visible=False))
        for pos, prompt_token in enumerate(tokens[:-1]):
            model.forward_token_preproc(prompt_token, state)
            yield (tokenizer.decode(tokens[:pos + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)
        max_tokens = int(max_tokens)
        for step in range(max_tokens):
            if step < min_tokens:
                logits[0] = -100
            # A separate loop variable is used here on purpose: reusing the
            # step index would break the last-step check below.
            for seen_token, count in counts.items():
                logits[seen_token] -= (count * alpha_f) + alpha_p
            token = int(np.argmax(logits))
            counts[token] = counts.get(token, 0) + 1
            if token == 0:
                break
            tokens += [token]
            ins = ins[1:] + [token]
            reached_end = ins == tokens_i
            if reached_end:
                # The model reproduced the start of the trailing text:
                # splice in the rest of it and finish.
                tokens += tokens_end[num_tokens_insert:]
            text = tokenizer.decode(tokens)
            yield (text, None)
            if reached_end or step == max_tokens - 1:
                # Done: skip the now-useless forward pass.
                break
            logits = model.forward_token(token, state)
        yield (text, None)
    except Exception as e:
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))
def classify_fn_inner2(inpt, clas):
    """Return the mean negative log-likelihood of ``inpt`` given a class prompt.

    The model is first primed with ``"This is an example of {clas} text:"``,
    then ``" {inpt}\\n"`` is fed token by token and the log-probability the
    model assigns to each following token is accumulated.

    Args:
        inpt: Text to score.
        clas: Natural-language class description inserted into the prompt.

    Returns:
        The negated sum of next-token log-probabilities divided by the number
        of tokens in the scored text minus one.
    """
    state = rwkv_rs.State(model)

    # Prime the recurrent state with the class-describing prompt.
    prompt_ids = tokenizer.encode(f"This is an example of {clas} text:").ids
    for prompt_id in prompt_ids:
        model.forward_token_preproc(prompt_id, state)

    # Score each token of the text against the model's prediction for it.
    text_ids = tokenizer.encode(f" {inpt}\n").ids
    total = 0
    for current_id, next_id in zip(text_ids, text_ids[1:]):
        log_probs = np.log(softmax(model.forward_token(current_id, state)))
        total += log_probs[next_id]

    return -total / (len(text_ids) - 1)
def softmax(x):
    """Return the softmax of ``x``.

    The maximum of ``x`` is subtracted before exponentiating so that large
    inputs do not overflow.
    """
    peak = np.max(x)
    weights = np.exp(x - peak)
    return weights / np.sum(weights)
def classify_fn(inpt, clas, clasneg):
    """Score ``inpt`` against a positive and a negative class description.

    Each class is scored with ``classify_fn_inner2``; the two negated losses
    are passed through ``softmax`` so the results sum to one.

    Returns:
        A dict with keys ``"+"`` (score for ``clas``) and ``"-"`` (score for
        ``clasneg``).
    """
    losses = [classify_fn_inner2(inpt, label) for label in (clas, clasneg)]
    positive, negative = softmax([-loss for loss in losses])
    return {"+": positive, "-": negative}
def generator_wrap(l, fn):
    """Wrap generator function ``fn`` so every result also carries button states.

    Each item produced by ``fn`` is converted to a list and extended with
    ``GT``. When ``fn`` finishes, or raises, the last item (or ``l`` ``None``
    placeholders if nothing was produced) is yielded once more extended with
    ``GF``; an exception raised by ``fn`` propagates after that final yield.

    Args:
        l: Number of outputs ``fn`` yields per item; used to build the
            placeholder result when ``fn`` yields nothing.
        fn: Generator function to wrap.

    Returns:
        A generator function taking the same positional arguments as ``fn``.
    """
    def wrap(*args):
        last_result = [None] * l
        closed_by_consumer = False
        try:
            for result in fn(*args):
                last_result = list(result)
                yield last_result + GT
        except GeneratorExit:
            # The consumer closed or cancelled us. Yielding again from here
            # would raise "generator ignored GeneratorExit", so just leave.
            closed_by_consumer = True
            raise
        finally:
            if not closed_by_consumer:
                yield last_result + GF
    return wrap
# NOTE(review): the source reached review with its indentation flattened; the
# nesting of the layout containers below is reconstructed -- confirm against
# the intended UI layout.
with gr.Blocks() as app:
    gr.Markdown(f"Running on `{model_path}`")
    # Hidden by default; the handlers make it visible when they report an error.
    error_box = gr.Text(label="Error", visible=False)

    with gr.Tab("Complete"):
        with gr.Row():
            inpt = gr.TextArea(label="Input")
            out = gr.TextArea(label="Output")
        complete = gr.Button("Complete", variant="primary")
        c_stop = gr.Button("Stop", variant="stop", visible=False)

    with gr.Tab("Insert"):
        gr.Markdown("Use `<|INSERT|>` to indicate a place to replace, if insert fails - end text won't be concatenated")
        with gr.Row():
            inpt_i = gr.TextArea(label="Input")
            out_i = gr.TextArea(label="Output")
        num_tokens_insert = gr.Slider(
            label="Number of tokens to compare for ending (from the beginning of 2nd part)",
            minimum=1,
            maximum=2048,
            value=1024,
            step=1,
        )
        insert = gr.Button("Insert", variant="primary")
        i_stop = gr.Button("Stop", variant="stop", visible=False)

    with gr.Tab("Classification W/O head"):
        gr.Markdown("This is an experimental classification attempt based on [this Twitter post](https://twitter.com/aicrumb/status/1625239547268280321)\n\nSettings at the bottom do no affect this example.")
        with gr.Row():
            inpt_c = gr.TextArea(label="Input")
            out_c = gr.Label(label="Output")
        clas = gr.Textbox(label="+ NL class/example to check against.")
        clasneg = gr.Textbox(label="- NL class/example to check against.")
        classify = gr.Button("Classify", variant="primary")

    # Generation settings passed to both the Complete and Insert handlers.
    with gr.Column():
        max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
        min_tokens = gr.Slider(label="Min Tokens", minimum=0, maximum=4096, step=1)
        alpha_f = gr.Slider(label="Alpha Frequency", minimum=0, maximum=100, step=0.01)
        alpha_p = gr.Slider(label="Alpha Presence", minimum=0, maximum=100, step=0.01)

    # Each streaming handler writes [text output, error box, action button,
    # stop button]; the matching Stop button cancels the running event and
    # restores the button visibility itself.
    c = complete.click(
        generator_wrap(2, complete_fn),
        [inpt, max_tokens, min_tokens, alpha_f, alpha_p],
        [out, error_box, complete, c_stop],
    )
    c_stop.click(
        lambda: (complete.update(visible=True), c_stop.update(visible=False)),
        inputs=None,
        outputs=[complete, c_stop],
        cancels=[c],
        queue=False,
    )

    i = insert.click(
        generator_wrap(2, insert_fn),
        [inpt_i, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert],
        [out_i, error_box, insert, i_stop],
    )
    i_stop.click(
        lambda: (insert.update(visible=True), i_stop.update(visible=False)),
        inputs=None,
        outputs=[insert, i_stop],
        cancels=[i],
        queue=False,
    )

    classify.click(classify_fn, [inpt_c, clas, clasneg], [out_c])

app.queue(concurrency_count=2)
app.launch()