Alexandr "MrSteyk" German committed
Commit: 785a54b
Parent(s): 2959d62
Commit message: insert

app.py CHANGED
@@ -30,6 +30,7 @@ def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
         text = inpt
         counts = [0]*tokenizer.get_vocab_size()
         tokens = tokenizer.encode(inpt).ids
+        yield (None, gr.Text.update(visible=False))
         # yield ("Preproc...", gr.Text.update(visible=False))
         # logits = model.forward(tokens, state)
         for i in range(len(tokens) - 1):
@@ -47,11 +48,11 @@ def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
             counts[token] += 1
             if token == 0:
                 break
-            if i == max_tokens - 1:
-                break
             tokens += [token]
             text = tokenizer.decode(tokens)
             yield (text, None)
+            if i == max_tokens - 1:
+                break
             logits = model.forward_token(token, state)
             yield (text, None)
     except Exception as e:
@@ -60,6 +61,49 @@ def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
     # finally:
     #     return (None, None)
 
+def insert_fn(inpt: str, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert):
+    try:
+        if inpt.count("<|INSERT|>") != 1:
+            yield ("Error...", gr.Text.update(value="Exactly one replace is allowed!", visible=True))
+            return
+        state = rwkv_rs.State(model)
+        text, end = inpt.split("<|INSERT|>")
+        counts = [0]*tokenizer.get_vocab_size()
+        tokens = tokenizer.encode(text).ids
+        tokens_end = tokenizer.encode(end).ids
+        tokens_i = tokens_end[:num_tokens_insert]
+        ins = [0]*len(tokens_i)
+        yield (None, gr.Text.update(visible=False))
+        for i in range(len(tokens) - 1):
+            model.forward_token_preproc(tokens[i], state)
+            yield (tokenizer.decode(tokens[:i + 1]), None)
+        logits = model.forward_token(tokens[-1], state)
+        yield (text, None)
+        max_tokens = int(max_tokens)
+        for i in range(max_tokens):
+            if i < min_tokens:
+                logits[0] = -100
+            for i in range(len(counts)):
+                logits[i] -= (counts[i] * alpha_f) + (float(counts[i] > 0) * alpha_p)
+            token = np.argmax(logits)
+            counts[token] += 1
+            if token == 0:
+                break
+            tokens += [token]
+            ins = ins[1:] + [token]
+            if ins == tokens_i:
+                tokens += tokens_end[num_tokens_insert:]
+                i = max_tokens - 1  # to break earlier...
+            text = tokenizer.decode(tokens)
+            yield (text, None)
+            if i == max_tokens - 1:
+                break
+            logits = model.forward_token(token, state)
+            yield (text, None)
+    except Exception as e:
+        print(e)
+        yield ("Error...", gr.Text.update(value=str(e), visible=True))
+
 def generator_wrap(l, fn):
     def wrap(*args):
         last_i = list([None] * l)
@@ -82,12 +126,14 @@ with gr.Blocks() as app:
             out = gr.TextArea(label="Output")
         complete = gr.Button("Complete", variant="primary")
         c_stop = gr.Button("Stop", variant="stop", visible=False)
-    with gr.Tab("Insert
-        gr.Markdown("
+    with gr.Tab("Insert"):
+        gr.Markdown("Use `<|INSERT|>` to indicate a place to replace, if insert fails - end text won't be concatenated")
         with gr.Row():
             inpt_i = gr.TextArea(label="Input")
             out_i = gr.TextArea(label="Output")
-
+        num_tokens_insert = gr.Slider(label="Number of tokens to compare for ending (from the beginning of 2nd part)", minimum=1, maximum=2048, value=1024, step=1)
+        insert = gr.Button("Insert", variant="submit")
+        i_stop = gr.Button("Stop", variant="stop", visible=False)
 
     with gr.Column():
         max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
@@ -95,10 +141,11 @@ with gr.Blocks() as app:
         alpha_f = gr.Slider(label="Alpha Frequency", minimum=0, maximum=100, step=0.01)
         alpha_p = gr.Slider(label="Alpha Presence", minimum=0, maximum=100, step=0.01)
 
-
-
-    c = complete.click(generator_wrap(2, complete_fn), [inpt, max_tokens, min_tokens, alpha_f, alpha_p], [out, error_box] + G)
+    c = complete.click(generator_wrap(2, complete_fn), [inpt, max_tokens, min_tokens, alpha_f, alpha_p], [out, error_box, complete, c_stop])
     c_stop.click(lambda: (complete.update(visible=True), c_stop.update(visible=False)), inputs=None, outputs=[complete, c_stop], cancels=[c], queue=False)
 
+    i = insert.click(generator_wrap(2, insert_fn), [inpt_i, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert], [out_i, error_box, insert, i_stop])
+    i_stop.click(lambda: (insert.update(visible=True), i_stop.update(visible=False)), inputs=None, outputs=[insert, i_stop], cancels=[i], queue=False)
+
 app.queue(concurrency_count=2)
 app.launch()
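
Note: the new insert_fn splits the input at the single <|INSERT|> marker, generates greedily from the first part, and keeps a rolling window of the last num_tokens_insert generated tokens; once that window equals the first num_tokens_insert tokens of the second part, the remainder of the ending is appended and generation stops. A minimal standalone sketch of that end-matching idea, assuming a hypothetical generate_token() in place of the commit's RWKV forward/argmax step:

def insert_between(prefix_tokens, end_tokens, num_tokens_insert, max_tokens, generate_token):
    # Sketch of insert_fn's stopping rule, without the model, state, or Gradio plumbing.
    tokens = list(prefix_tokens)
    tokens_i = end_tokens[:num_tokens_insert]  # target window: the start of the ending
    ins = [0] * len(tokens_i)                  # rolling window of the last generated tokens
    for _ in range(max_tokens):
        token = generate_token(tokens)
        if token == 0:                         # token 0 ends generation, as in the commit
            break
        tokens.append(token)
        ins = ins[1:] + [token]                # slide the window by one token
        if ins == tokens_i:                    # generation has reproduced the start of the ending
            tokens += end_tokens[num_tokens_insert:]
            break
    return tokens

For example, an input like "Dear team, <|INSERT|> Best regards, Alex" (illustrative, not from the commit) seeds generation with the text before the marker, and "Best regards, Alex" is only attached once the generated tokens reproduce its opening tokens; per the added Markdown hint, the ending is not concatenated when that match never happens.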
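Note: both complete_fn and insert_fn damp repetition before the argmax by subtracting counts[i] * alpha_f, plus alpha_p once a token has appeared at all, from each logit. A vectorized NumPy equivalent of that per-token loop, as a sketch (the function name is illustrative, not from the commit):

import numpy as np

def apply_repetition_penalty(logits, counts, alpha_f, alpha_p):
    # Same adjustment as the commit's loop, applied to whole arrays:
    # logits[i] -= counts[i] * alpha_f + (counts[i] > 0) * alpha_p
    counts = np.asarray(counts, dtype=np.float32)
    return np.asarray(logits, dtype=np.float32) - (counts * alpha_f + (counts > 0) * alpha_p)

# e.g. a token already emitted twice with alpha_f=0.5 and alpha_p=0.2 loses 2*0.5 + 0.2 = 1.2 from its logit.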
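Note: generator_wrap's body is mostly outside this diff; the visible change is that the click handlers now list the start/stop buttons explicitly as outputs ([out, error_box, complete, c_stop] instead of [out, error_box] + G). A plausible sketch of such a wrapper, assuming it pads each yield of the l-output generator with button-visibility updates (this is an assumption, not the commit's actual implementation):

import gradio as gr

def generator_wrap(l, fn):
    # Assumed behaviour: show the Stop button and hide the start button while the
    # wrapped generator runs, keep the last non-None value per output slot, then restore.
    def wrap(*args):
        last_i = list([None] * l)
        yield tuple(last_i) + (gr.Button.update(visible=False), gr.Button.update(visible=True))
        for out in fn(*args):
            last_i = [o if o is not None else last_i[n] for n, o in enumerate(out)]
            yield tuple(last_i) + (gr.Button.update(), gr.Button.update())
        yield tuple(last_i) + (gr.Button.update(visible=True), gr.Button.update(visible=False))
    return wrap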