|
import os |
|
|
|
import rwkv_rs |
|
import numpy as np |
|
import huggingface_hub |
|
import tokenizers |
|
|
|
import gradio as gr |
|
|
|
# Prefer a checkpoint sitting next to the app; otherwise fetch it from the
# Hugging Face Hub (hf_hub_download returns the local cached file path).
model_path = "./rnn.safetensors"
if not os.path.exists(model_path):
    model_path = huggingface_hub.hf_hub_download(
        repo_id="mrsteyk/RWKV-LM-safetensors",
        filename="RWKV-4-Pile-7B-Instruct-test1-20230124.rnn.safetensors",
    )
# Explicit check instead of `assert`, which is stripped under `python -O`.
if model_path is None:
    raise FileNotFoundError("Could not find or download the RWKV model checkpoint")

# Single model instance shared by all request handlers; each handler creates
# its own rwkv_rs.State for the recurrent state.
model = rwkv_rs.Rwkv(model_path)
# GPT-NeoX-20B tokenizer, used to encode prompts and decode generated ids.
tokenizer = tokenizers.Tokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
|
|
def _button_pair(start_visible):
    """Return visibility updates for a (start button, stop button) pair.

    The stop button is always given the opposite visibility of the start one.
    """
    return [
        gr.Button.update(visible=start_visible),
        gr.Button.update(visible=not start_visible),
    ]


# Appended to every streamed output while a generation is running:
# hide the start button and show its Stop button.
GT = _button_pair(False)
# Appended to the final output: show the start button again, hide Stop.
GF = _button_pair(True)
|
|
|
def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
    """Greedily extend ``inpt`` with the RWKV model, streaming progress.

    Args:
        inpt: Prompt text.
        max_tokens: Maximum number of tokens to generate (passed through
            ``int`` because slider values may arrive as floats).
        min_tokens: For generation steps below this value the logit of token
            id 0 (which ends generation) is set to -100.
        alpha_f: Frequency penalty, subtracted once per earlier occurrence of
            a token in the generated output.
        alpha_p: Presence penalty, subtracted once from every token that has
            already been generated.

    Yields:
        ``(output_text, error_box_update)`` tuples; ``None`` is yielded for an
        output that has nothing new. Errors are reported by yielding
        ``("Error...", <visible error box with the message>)`` rather than
        raising.
    """
    try:
        tokens = tokenizer.encode(inpt).ids
        if not tokens:
            yield ("Error...", gr.Text.update(value="Input is empty!", visible=True))
            return

        state = rwkv_rs.State(model)
        # Occurrence count of each token id among *generated* tokens only.
        counts = np.zeros(tokenizer.get_vocab_size(), dtype=np.int64)
        text = inpt
        # Hide any error left over from a previous run.
        yield (None, gr.Text.update(visible=False))

        # Prefill: advance the state through all but the last prompt token
        # (their logits are not needed), showing progress as we go.
        for pos, tok in enumerate(tokens[:-1]):
            model.forward_token_preproc(tok, state)
            yield (tokenizer.decode(tokens[: pos + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)

        max_tokens = int(max_tokens)
        for step in range(max_tokens):
            logits = np.asarray(logits)
            if step < min_tokens:
                logits[0] = -100
            # Frequency + presence penalties, vectorised over the vocabulary.
            logits[: counts.size] -= counts * alpha_f + (counts > 0) * alpha_p

            token = int(np.argmax(logits))
            counts[token] += 1
            if token == 0:
                break

            tokens.append(token)
            text = tokenizer.decode(tokens)
            yield (text, None)

            # No need to run the model again after the final token.
            if step == max_tokens - 1:
                break
            logits = model.forward_token(token, state)

        yield (text, None)
    except Exception as e:
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))
|
|
|
|
|
|
|
def insert_fn(inpt: str, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert):
    """Fill the ``<|INSERT|>`` gap in ``inpt`` by greedy generation, streaming progress.

    The text before the marker is the prompt. Generation proceeds greedily with
    the same min-token and frequency/presence penalty rules as completion. As
    soon as the most recently generated tokens equal the first
    ``num_tokens_insert`` tokens of the text after the marker, the rest of that
    text is appended and generation stops. If that never happens, the end text
    is not appended.

    Args:
        inpt: Text containing exactly one ``<|INSERT|>`` marker.
        max_tokens: Maximum number of tokens to generate.
        min_tokens: For steps below this value the logit of token id 0 (which
            ends generation) is set to -100.
        alpha_f: Frequency penalty per earlier occurrence of a generated token.
        alpha_p: Presence penalty for any already-generated token.
        num_tokens_insert: How many leading tokens of the end text must be
            reproduced to count as a successful join.

    Yields:
        ``(output_text, error_box_update)`` tuples; ``None`` is yielded for an
        output that has nothing new. Errors are reported by yielding
        ``("Error...", <visible error box with the message>)``.
    """
    try:
        if inpt.count("<|INSERT|>") != 1:
            yield ("Error...", gr.Text.update(value="Exactly one replace is allowed!", visible=True))
            return

        text, end = inpt.split("<|INSERT|>")
        tokens = tokenizer.encode(text).ids
        if not tokens:
            yield ("Error...", gr.Text.update(value="Text before <|INSERT|> is empty!", visible=True))
            return

        tokens_end = tokenizer.encode(end).ids
        # Slider values may arrive as floats; slice indices must be ints.
        tokens_i = tokens_end[: int(num_tokens_insert)]

        state = rwkv_rs.State(model)
        # Occurrence count of each token id among *generated* tokens only.
        counts = np.zeros(tokenizer.get_vocab_size(), dtype=np.int64)
        generated = []
        # Hide any error left over from a previous run.
        yield (None, gr.Text.update(visible=False))

        # Prefill: advance the state through all but the last prompt token.
        for pos, tok in enumerate(tokens[:-1]):
            model.forward_token_preproc(tok, state)
            yield (tokenizer.decode(tokens[: pos + 1]), None)
        logits = model.forward_token(tokens[-1], state)
        yield (text, None)

        max_tokens = int(max_tokens)
        for step in range(max_tokens):
            logits = np.asarray(logits)
            if step < min_tokens:
                logits[0] = -100
            # Frequency + presence penalties, vectorised over the vocabulary.
            logits[: counts.size] -= counts * alpha_f + (counts > 0) * alpha_p

            token = int(np.argmax(logits))
            counts[token] += 1
            if token == 0:
                break

            tokens.append(token)
            generated.append(token)
            # Only generated tokens are compared, so the prompt cannot complete
            # the match; an empty end text never matches.
            inserted = bool(tokens_i) and generated[-len(tokens_i):] == tokens_i
            if inserted:
                tokens += tokens_end[len(tokens_i):]

            text = tokenizer.decode(tokens)
            yield (text, None)

            # Stop once joined, and skip the model call after the final token.
            if inserted or step == max_tokens - 1:
                break
            logits = model.forward_token(token, state)

        yield (text, None)
    except Exception as e:
        print(e)
        yield ("Error...", gr.Text.update(value=str(e), visible=True))
|
|
|
def classify_fn_inner2(inpt, clas):
    """Return the mean per-token negative log-likelihood of ``inpt`` given a class prompt.

    The model state is primed with ``"This is an example of {clas} text:"``.
    Every token of ``inpt`` (prefixed with a space and followed by a newline)
    is then scored, including the first one. Lower values mean the model finds
    ``inpt`` more plausible as an example of ``clas``.
    """
    state = rwkv_rs.State(model)
    prompt = tokenizer.encode(f"This is an example of {clas} text:").ids
    # Always non-empty: the string contains at least " " and the newline.
    target = tokenizer.encode(f" {inpt}\n").ids

    # Run the prompt through the model. Logits are only needed from its last
    # token, because they predict the first token of the input.
    for tok in prompt[:-1]:
        model.forward_token_preproc(tok, state)
    logits = model.forward_token(prompt[-1], state)

    nll = 0.0
    for pos, tok in enumerate(target):
        # Numerically stable log-softmax; np.log(softmax(x)) gives -inf when
        # a probability underflows to zero.
        shifted = np.asarray(logits, dtype=np.float64)
        shifted = shifted - shifted.max()
        nll -= shifted[tok] - np.log(np.exp(shifted).sum())
        # The logits after the final token are never used, so skip that call.
        if pos + 1 < len(target):
            logits = model.forward_token(tok, state)

    return nll / len(target)
|
|
|
def softmax(x):
    """Return softmax probabilities over all elements of ``x``.

    ``x`` may be a list or an array. The maximum is subtracted before
    exponentiating so that large scores do not overflow.
    """
    exps = np.exp(np.subtract(x, np.max(x)))
    return exps / np.sum(exps)
|
|
|
def classify_fn(inpt, clas, clasneg):
    """Compare how well ``inpt`` fits the ``clas`` versus the ``clasneg`` description.

    Each class is scored with ``classify_fn_inner2``. The two negated losses
    are turned into probabilities with ``softmax``, so the values sum to 1.
    The result is the mapping shown by the ``gr.Label`` output: ``{"+": p, "-": q}``.
    """
    pos_loss, neg_loss = (classify_fn_inner2(inpt, label) for label in (clas, clasneg))
    probs = softmax([-pos_loss, -neg_loss])
    return {"+": probs[0], "-": probs[1]}
|
|
|
def generator_wrap(l, fn):
    """Wrap a streaming handler so it also toggles the start/stop buttons.

    Args:
        l: Number of values in each item yielded by ``fn``.
        fn: Generator function (or any callable returning an iterable) whose
            items are sequences of ``l`` output values.

    Returns:
        A generator function taking the same arguments as ``fn``:
        - While ``fn`` runs, it yields each item extended with ``GT``.
        - When ``fn`` finishes or raises, it yields the last item (or ``l``
          ``None`` values if there was none) extended with ``GF``. On an
          exception it then re-raises.
        - If the consumer closes it early (e.g. the event is cancelled by a
          Stop button), it yields nothing more.
    """
    def wrap(*args):
        last = [None] * l
        it = fn(*args)
        try:
            for out in it:
                last = list(out)
                yield last + GT
        except GeneratorExit:
            # Closed by the consumer: yielding again here would raise
            # "RuntimeError: generator ignored GeneratorExit".
            raise
        except Exception:
            yield last + GF
            raise
        else:
            yield last + GF
        finally:
            # Release the wrapped generator promptly instead of waiting for GC.
            close = getattr(it, "close", None)
            if close is not None:
                close()

    return wrap
|
|
|
|
|
with gr.Blocks() as app:
    gr.Markdown(f"Running on `{model_path}`")
    # Shared error box, hidden until a streaming handler reports an error.
    error_box = gr.Text(label="Error", visible=False)

    with gr.Tab("Complete"):
        with gr.Row():
            inpt = gr.TextArea(label="Input")
            out = gr.TextArea(label="Output")
        complete = gr.Button("Complete", variant="primary")
        c_stop = gr.Button("Stop", variant="stop", visible=False)

    with gr.Tab("Insert"):
        gr.Markdown("Use `<|INSERT|>` to indicate a place to replace, if insert fails - end text won't be concatenated")
        with gr.Row():
            inpt_i = gr.TextArea(label="Input")
            out_i = gr.TextArea(label="Output")
        num_tokens_insert = gr.Slider(
            label="Number of tokens to compare for ending (from the beginning of 2nd part)",
            minimum=1,
            maximum=2048,
            value=1024,
            step=1,
        )
        insert = gr.Button("Insert", variant="primary")
        i_stop = gr.Button("Stop", variant="stop", visible=False)

    with gr.Tab("Classification W/O head"):
        gr.Markdown(
            "This is an experimental classification attempt based on "
            "[this Twitter post](https://twitter.com/aicrumb/status/1625239547268280321)"
            "\n\nSettings at the bottom do not affect this example."
        )
        with gr.Row():
            inpt_c = gr.TextArea(label="Input")
            out_c = gr.Label(label="Output")
        clas = gr.Textbox(label="+ NL class/example to check against.")
        clasneg = gr.Textbox(label="- NL class/example to check against.")
        classify = gr.Button("Classify", variant="primary")

    # Generation settings shared by the Complete and Insert tabs.
    with gr.Column():
        max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
        min_tokens = gr.Slider(label="Min Tokens", minimum=0, maximum=4096, step=1)
        alpha_f = gr.Slider(label="Alpha Frequency", minimum=0, maximum=100, step=0.01)
        alpha_p = gr.Slider(label="Alpha Presence", minimum=0, maximum=100, step=0.01)

    # The wrapped handlers stream [text, error box] plus two button-visibility
    # updates (start hidden / Stop shown while running, swapped at the end).
    complete_event = complete.click(
        generator_wrap(2, complete_fn),
        [inpt, max_tokens, min_tokens, alpha_f, alpha_p],
        [out, error_box, complete, c_stop],
    )
    # Stop cancels the running event and restores the buttons itself; it runs
    # unqueued so it is not stuck behind the job it is cancelling.
    c_stop.click(
        lambda: (complete.update(visible=True), c_stop.update(visible=False)),
        inputs=None,
        outputs=[complete, c_stop],
        cancels=[complete_event],
        queue=False,
    )

    insert_event = insert.click(
        generator_wrap(2, insert_fn),
        [inpt_i, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert],
        [out_i, error_box, insert, i_stop],
    )
    i_stop.click(
        lambda: (insert.update(visible=True), i_stop.update(visible=False)),
        inputs=None,
        outputs=[insert, i_stop],
        cancels=[insert_event],
        queue=False,
    )

    classify.click(classify_fn, [inpt_c, clas, clasneg], [out_c])

# Queueing is required for streaming (generator) handlers and cancellation.
app.queue(concurrency_count=2)
app.launch()