Prgckwb commited on
Commit
d9d3f4b
1 Parent(s): e603ef9

:tada: init

Browse files
Files changed (1) hide show
  1. app.py +35 -32
app.py CHANGED
@@ -6,13 +6,15 @@ from transformers import AutoTokenizer, CLIPTokenizerFast, T5TokenizerFast
6
 
7
  def load_tokenizers(model_id: str) -> list[CLIPTokenizerFast | T5TokenizerFast | None]:
8
  config = DiffusionPipeline.load_config(model_id)
9
- num_tokenizers = sum('tokenizer' in key for key in config.keys())
10
 
11
  if not 1 <= num_tokenizers <= 3:
12
- raise gr.Error(f'Invalid number of tokenizers: {num_tokenizers}')
13
 
14
  tokenizers = [
15
- AutoTokenizer.from_pretrained(model_id, subfolder=f'tokenizer{"" if i == 0 else f"_{i + 1}"}')
 
 
16
  for i in range(num_tokenizers)
17
  ]
18
 
@@ -22,7 +24,7 @@ def load_tokenizers(model_id: str) -> list[CLIPTokenizerFast | T5TokenizerFast |
22
  return tokenizers
23
 
24
 
25
- @torch.inference_mode()
26
  def inference(model_id: str, input_text: str):
27
  tokenizers = load_tokenizers(model_id)
28
 
@@ -30,17 +32,19 @@ def inference(model_id: str, input_text: str):
30
  special_tokens_components = []
31
  for i, tokenizer in enumerate(tokenizers):
32
  if tokenizer:
33
- label_text = f'Tokenizer {i + 1}: {tokenizer.__class__.__name__}'
34
 
35
  # テキストとトークンIDのペアを作成
36
  input_ids = tokenizer(
37
  text=input_text,
38
  truncation=True,
39
  return_length=False,
40
- return_overflowing_tokens=False
41
  ).input_ids
42
  decoded_tokens = [tokenizer.decode(id_) for id_ in input_ids]
43
- token_pairs = [(str(token), str(id_)) for token, id_ in zip(decoded_tokens, input_ids)]
 
 
44
  output_text_pair_component = gr.HighlightedText(
45
  label=label_text,
46
  value=token_pairs,
@@ -51,7 +55,7 @@ def inference(model_id: str, input_text: str):
51
  # スペシャルトークンを追加
52
  special_tokens = []
53
  for k, v in tokenizer.special_tokens_map.items():
54
- if k == 'additional_special_tokens':
55
  continue
56
  special_token_map = (str(k), str(v))
57
  special_tokens.append(special_token_map)
@@ -71,7 +75,7 @@ def inference(model_id: str, input_text: str):
71
  return text_pairs_components + special_tokens_components
72
 
73
 
74
- if __name__ == '__main__':
75
  theme = gr.themes.Soft(
76
  primary_hue=gr.themes.colors.emerald,
77
  secondary_hue=gr.themes.colors.emerald,
@@ -79,29 +83,29 @@ if __name__ == '__main__':
79
  with gr.Blocks(theme=theme) as demo:
80
  with gr.Column():
81
  input_model_id = gr.Dropdown(
82
- label='Model ID',
83
  choices=[
84
- 'black-forest-labs/FLUX.1-dev',
85
- 'black-forest-labs/FLUX.1-schnell',
86
- 'stabilityai/stable-diffusion-3-medium-diffusers',
87
- 'stabilityai/stable-diffusion-xl-base-1.0',
88
- 'stable-diffusion-v1-5/stable-diffusion-v1-5',
89
- 'stabilityai/japanese-stable-diffusion-xl',
90
- 'rinna/japanese-stable-diffusion',
91
  ],
92
- value='black-forest-labs/FLUX.1-dev',
93
  )
94
  input_text = gr.Textbox(
95
- label='Input Text',
96
- placeholder='Enter text here',
97
  )
98
 
99
- with gr.Tab(label='Tokenization Outputs'):
100
  with gr.Column():
101
  output_highlighted_text_1 = gr.HighlightedText()
102
  output_highlighted_text_2 = gr.HighlightedText()
103
  output_highlighted_text_3 = gr.HighlightedText()
104
- with gr.Tab(label='Special Tokens'):
105
  with gr.Column():
106
  output_special_tokens_1 = gr.HighlightedText()
107
  output_special_tokens_2 = gr.HighlightedText()
@@ -109,7 +113,7 @@ if __name__ == '__main__':
109
 
110
  with gr.Row():
111
  clear_button = gr.ClearButton(components=[input_text])
112
- submit_button = gr.Button('Run', variant='primary')
113
 
114
  all_inputs = [input_model_id, input_text]
115
  all_output = [
@@ -120,22 +124,21 @@ if __name__ == '__main__':
120
  output_special_tokens_2,
121
  output_special_tokens_3,
122
  ]
123
- submit_button.click(
124
- fn=inference,
125
- inputs=all_inputs,
126
- outputs=all_output
127
- )
128
 
129
  examples = gr.Examples(
130
  fn=inference,
131
  inputs=all_inputs,
132
  outputs=all_output,
133
  examples=[
134
- ['black-forest-labs/FLUX.1-dev', 'a photo of cat'],
135
- ['stabilityai/stable-diffusion-3-medium-diffusers', 'cat holding sign saying "I am a cat"'],
136
- ['rinna/japanese-stable-diffusion', '空を飛んでいるネコの写真 油絵']
 
 
 
137
  ],
138
- cache_examples=True
139
  )
140
 
141
  demo.queue().launch()
 
6
 
7
  def load_tokenizers(model_id: str) -> list[CLIPTokenizerFast | T5TokenizerFast | None]:
8
  config = DiffusionPipeline.load_config(model_id)
9
+ num_tokenizers = sum("tokenizer" in key for key in config.keys())
10
 
11
  if not 1 <= num_tokenizers <= 3:
12
+ raise gr.Error(f"Invalid number of tokenizers: {num_tokenizers}")
13
 
14
  tokenizers = [
15
+ AutoTokenizer.from_pretrained(
16
+ model_id, subfolder=f'tokenizer{"" if i == 0 else f"_{i + 1}"}'
17
+ )
18
  for i in range(num_tokenizers)
19
  ]
20
 
 
24
  return tokenizers
25
 
26
 
27
+ @torch.no_grad()
28
  def inference(model_id: str, input_text: str):
29
  tokenizers = load_tokenizers(model_id)
30
 
 
32
  special_tokens_components = []
33
  for i, tokenizer in enumerate(tokenizers):
34
  if tokenizer:
35
+ label_text = f"Tokenizer {i + 1}: {tokenizer.__class__.__name__}"
36
 
37
  # テキストとトークンIDのペアを作成
38
  input_ids = tokenizer(
39
  text=input_text,
40
  truncation=True,
41
  return_length=False,
42
+ return_overflowing_tokens=False,
43
  ).input_ids
44
  decoded_tokens = [tokenizer.decode(id_) for id_ in input_ids]
45
+ token_pairs = [
46
+ (str(token), str(id_)) for token, id_ in zip(decoded_tokens, input_ids)
47
+ ]
48
  output_text_pair_component = gr.HighlightedText(
49
  label=label_text,
50
  value=token_pairs,
 
55
  # スペシャルトークンを追加
56
  special_tokens = []
57
  for k, v in tokenizer.special_tokens_map.items():
58
+ if k == "additional_special_tokens":
59
  continue
60
  special_token_map = (str(k), str(v))
61
  special_tokens.append(special_token_map)
 
75
  return text_pairs_components + special_tokens_components
76
 
77
 
78
+ if __name__ == "__main__":
79
  theme = gr.themes.Soft(
80
  primary_hue=gr.themes.colors.emerald,
81
  secondary_hue=gr.themes.colors.emerald,
 
83
  with gr.Blocks(theme=theme) as demo:
84
  with gr.Column():
85
  input_model_id = gr.Dropdown(
86
+ label="Model ID",
87
  choices=[
88
+ "black-forest-labs/FLUX.1-dev",
89
+ "black-forest-labs/FLUX.1-schnell",
90
+ "stabilityai/stable-diffusion-3-medium-diffusers",
91
+ "stabilityai/stable-diffusion-xl-base-1.0",
92
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
93
+ "stabilityai/japanese-stable-diffusion-xl",
94
+ "rinna/japanese-stable-diffusion",
95
  ],
96
+ value="black-forest-labs/FLUX.1-dev",
97
  )
98
  input_text = gr.Textbox(
99
+ label="Input Text",
100
+ placeholder="Enter text here",
101
  )
102
 
103
+ with gr.Tab(label="Tokenization Outputs"):
104
  with gr.Column():
105
  output_highlighted_text_1 = gr.HighlightedText()
106
  output_highlighted_text_2 = gr.HighlightedText()
107
  output_highlighted_text_3 = gr.HighlightedText()
108
+ with gr.Tab(label="Special Tokens"):
109
  with gr.Column():
110
  output_special_tokens_1 = gr.HighlightedText()
111
  output_special_tokens_2 = gr.HighlightedText()
 
113
 
114
  with gr.Row():
115
  clear_button = gr.ClearButton(components=[input_text])
116
+ submit_button = gr.Button("Run", variant="primary")
117
 
118
  all_inputs = [input_model_id, input_text]
119
  all_output = [
 
124
  output_special_tokens_2,
125
  output_special_tokens_3,
126
  ]
127
+ submit_button.click(fn=inference, inputs=all_inputs, outputs=all_output)
 
 
 
 
128
 
129
  examples = gr.Examples(
130
  fn=inference,
131
  inputs=all_inputs,
132
  outputs=all_output,
133
  examples=[
134
+ ["black-forest-labs/FLUX.1-dev", "a photo of cat"],
135
+ [
136
+ "stabilityai/stable-diffusion-3-medium-diffusers",
137
+ 'cat holding sign saying "I am a cat"',
138
+ ],
139
+ ["rinna/japanese-stable-diffusion", "空を飛んでいるネコの写真 油絵"],
140
  ],
141
+ cache_examples=True,
142
  )
143
 
144
  demo.queue().launch()