Alexandr "MrSteyk" German committed on
Commit
785a54b
1 Parent(s): 2959d62
Files changed (1) hide show
  1. app.py +55 -8
app.py CHANGED
@@ -30,6 +30,7 @@ def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
30
  text = inpt
31
  counts = [0]*tokenizer.get_vocab_size()
32
  tokens = tokenizer.encode(inpt).ids
 
33
  # yield ("Preproc...", gr.Text.update(visible=False))
34
  # logits = model.forward(tokens, state)
35
  for i in range(len(tokens) - 1):
@@ -47,11 +48,11 @@ def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
47
  counts[token] += 1
48
  if token == 0:
49
  break
50
- if i == max_tokens - 1:
51
- break
52
  tokens += [token]
53
  text = tokenizer.decode(tokens)
54
  yield (text, None)
 
 
55
  logits = model.forward_token(token, state)
56
  yield (text, None)
57
  except Exception as e:
@@ -60,6 +61,49 @@ def complete_fn(inpt, max_tokens, min_tokens, alpha_f, alpha_p):
60
  # finally:
61
  # return (None, None)
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def generator_wrap(l, fn):
64
  def wrap(*args):
65
  last_i = list([None] * l)
@@ -82,12 +126,14 @@ with gr.Blocks() as app:
82
  out = gr.TextArea(label="Output")
83
  complete = gr.Button("Complete", variant="primary")
84
  c_stop = gr.Button("Stop", variant="stop", visible=False)
85
- with gr.Tab("Insert (WIP)"):
86
- gr.Markdown("WIP, use `<|INSERT|>` to indicate a place to replace")
87
  with gr.Row():
88
  inpt_i = gr.TextArea(label="Input")
89
  out_i = gr.TextArea(label="Output")
90
- insert = gr.Button("Insert")
 
 
91
 
92
  with gr.Column():
93
  max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
@@ -95,10 +141,11 @@ with gr.Blocks() as app:
95
  alpha_f = gr.Slider(label="Alpha Frequency", minimum=0, maximum=100, step=0.01)
96
  alpha_p = gr.Slider(label="Alpha Presence", minimum=0, maximum=100, step=0.01)
97
 
98
- G = [complete, c_stop]
99
-
100
- c = complete.click(generator_wrap(2, complete_fn), [inpt, max_tokens, min_tokens, alpha_f, alpha_p], [out, error_box] + G)
101
  c_stop.click(lambda: (complete.update(visible=True), c_stop.update(visible=False)), inputs=None, outputs=[complete, c_stop], cancels=[c], queue=False)
102
 
 
 
 
103
  app.queue(concurrency_count=2)
104
  app.launch()
 
30
  text = inpt
31
  counts = [0]*tokenizer.get_vocab_size()
32
  tokens = tokenizer.encode(inpt).ids
33
+ yield (None, gr.Text.update(visible=False))
34
  # yield ("Preproc...", gr.Text.update(visible=False))
35
  # logits = model.forward(tokens, state)
36
  for i in range(len(tokens) - 1):
 
48
  counts[token] += 1
49
  if token == 0:
50
  break
 
 
51
  tokens += [token]
52
  text = tokenizer.decode(tokens)
53
  yield (text, None)
54
+ if i == max_tokens - 1:
55
+ break
56
  logits = model.forward_token(token, state)
57
  yield (text, None)
58
  except Exception as e:
 
61
  # finally:
62
  # return (None, None)
63
 
64
+ def insert_fn(inpt: str, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert):
65
+ try:
66
+ if inpt.count("<|INSERT|>") != 1:
67
+ yield ("Error...", gr.Text.update(value="Exactly one replace is allowed!", visible=True))
68
+ return
69
+ state = rwkv_rs.State(model)
70
+ text, end = inpt.split("<|INSERT|>")
71
+ counts = [0]*tokenizer.get_vocab_size()
72
+ tokens = tokenizer.encode(text).ids
73
+ tokens_end = tokenizer.encode(end).ids
74
+ tokens_i = tokens_end[:num_tokens_insert]
75
+ ins = [0]*len(tokens_i)
76
+ yield (None, gr.Text.update(visible=False))
77
+ for i in range(len(tokens) - 1):
78
+ model.forward_token_preproc(tokens[i], state)
79
+ yield (tokenizer.decode(tokens[:i + 1]), None)
80
+ logits = model.forward_token(tokens[-1], state)
81
+ yield (text, None)
82
+ max_tokens = int(max_tokens)
83
+ for i in range(max_tokens):
84
+ if i < min_tokens:
85
+ logits[0] = -100
86
+ for i in range(len(counts)):
87
+ logits[i] -= (counts[i]* alpha_f) + (float(counts[i] > 0) * alpha_p)
88
+ token = np.argmax(logits)
89
+ counts[token] += 1
90
+ if token == 0:
91
+ break
92
+ tokens += [token]
93
+ ins = ins[1:] + [token]
94
+ if ins == tokens_i:
95
+ tokens += tokens_end[num_tokens_insert:]
96
+ i = max_tokens - 1 # to break earlier...
97
+ text = tokenizer.decode(tokens)
98
+ yield (text, None)
99
+ if i == max_tokens - 1:
100
+ break
101
+ logits = model.forward_token(token, state)
102
+ yield (text, None)
103
+ except Exception as e:
104
+ print(e)
105
+ yield ("Error...", gr.Text.update(value=str(e), visible=True))
106
+
107
  def generator_wrap(l, fn):
108
  def wrap(*args):
109
  last_i = list([None] * l)
 
126
  out = gr.TextArea(label="Output")
127
  complete = gr.Button("Complete", variant="primary")
128
  c_stop = gr.Button("Stop", variant="stop", visible=False)
129
+ with gr.Tab("Insert"):
130
+ gr.Markdown("Use `<|INSERT|>` to indicate a place to replace, if insert fails - end text won't be concatenated")
131
  with gr.Row():
132
  inpt_i = gr.TextArea(label="Input")
133
  out_i = gr.TextArea(label="Output")
134
+ num_tokens_insert = gr.Slider(label="Number of tokens to compare for ending (from the beginning of 2nd part)", minimum=1, maximum=2048, value=1024, step=1)
135
+ insert = gr.Button("Insert", variant="submit")
136
+ i_stop = gr.Button("Stop", variant="stop", visible=False)
137
 
138
  with gr.Column():
139
  max_tokens = gr.Slider(label="Max Tokens", minimum=1, maximum=4096, step=1, value=767)
 
141
  alpha_f = gr.Slider(label="Alpha Frequency", minimum=0, maximum=100, step=0.01)
142
  alpha_p = gr.Slider(label="Alpha Presence", minimum=0, maximum=100, step=0.01)
143
 
144
+ c = complete.click(generator_wrap(2, complete_fn), [inpt, max_tokens, min_tokens, alpha_f, alpha_p], [out, error_box, complete, c_stop])
 
 
145
  c_stop.click(lambda: (complete.update(visible=True), c_stop.update(visible=False)), inputs=None, outputs=[complete, c_stop], cancels=[c], queue=False)
146
 
147
+ i = insert.click(generator_wrap(2, insert_fn), [inpt_i, max_tokens, min_tokens, alpha_f, alpha_p, num_tokens_insert], [out_i, error_box, insert, i_stop])
148
+ i_stop.click(lambda: (insert.update(visible=True), i_stop.update(visible=False)), inputs=None, outputs=[insert, i_stop], cancels=[i], queue=False)
149
+
150
  app.queue(concurrency_count=2)
151
  app.launch()