import gradio as gr
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, CLIPTokenizerFast, T5TokenizerFast
def load_tokenizers(model_id: str) -> list[CLIPTokenizerFast | T5TokenizerFast | None]:
    """Load every tokenizer shipped with a diffusers pipeline.

    Reads the pipeline config for *model_id*, counts the ``tokenizer*``
    entries, loads each one from its subfolder (``tokenizer``,
    ``tokenizer_2``, ``tokenizer_3``), and pads the result with ``None``
    so the returned list always has exactly three slots.
    """
    config = DiffusionPipeline.load_config(model_id)
    num_tokenizers = len([key for key in config if "tokenizer" in key])
    if num_tokenizers < 1 or num_tokenizers > 3:
        raise gr.Error(f"Invalid number of tokenizers: {num_tokenizers}")
    loaded = []
    for index in range(num_tokenizers):
        # diffusers convention: first subfolder is "tokenizer", then "tokenizer_2", ...
        subfolder = "tokenizer" if index == 0 else f"tokenizer_{index + 1}"
        loaded.append(AutoTokenizer.from_pretrained(model_id, subfolder=subfolder))
    # Pad with None so callers can always unpack three slots.
    while len(loaded) < 3:
        loaded.append(None)
    return loaded
@torch.no_grad()
def inference(model_id: str, input_text: str):
    """Tokenize *input_text* with every tokenizer bundled in *model_id*.

    Returns six ``gr.HighlightedText`` components: three token/ID views
    followed by three special-token views. Slots beyond the pipeline's
    tokenizer count are returned as hidden components so the output
    arity stays constant for the Gradio bindings.
    """
    tokenizers = load_tokenizers(model_id)
    text_pairs_components = []
    special_tokens_components = []
    for i, tokenizer in enumerate(tokenizers):
        # Use `is not None`, not truthiness: tokenizers define __len__
        # (vocab size), so a present-but-falsy tokenizer would otherwise
        # be silently skipped.
        if tokenizer is not None:
            label_text = f"Tokenizer {i + 1}: {tokenizer.__class__.__name__}"
            # Build (decoded token, token id) pairs for display.
            input_ids = tokenizer(
                text=input_text,
                truncation=True,
                return_length=False,
                return_overflowing_tokens=False,
            ).input_ids
            decoded_tokens = [tokenizer.decode(id_) for id_ in input_ids]
            token_pairs = [
                (str(token), str(id_)) for token, id_ in zip(decoded_tokens, input_ids)
            ]
            output_text_pair_component = gr.HighlightedText(
                label=label_text,
                value=token_pairs,
                visible=True,
                show_legend=True,
            )
            # Collect special tokens, skipping the aggregated
            # "additional_special_tokens" list entry (not a single token).
            special_tokens = []
            for k, v in tokenizer.special_tokens_map.items():
                if k == "additional_special_tokens":
                    continue
                special_tokens.append((str(k), str(v)))
            output_special_tokens_component = gr.HighlightedText(
                label=label_text,
                value=special_tokens,
                visible=True,
                show_legend=True,
            )
        else:
            # Empty slot: emit hidden placeholders to keep six outputs.
            output_text_pair_component = gr.HighlightedText(visible=False)
            output_special_tokens_component = gr.HighlightedText(visible=False)
        text_pairs_components.append(output_text_pair_component)
        special_tokens_components.append(output_special_tokens_component)
    return text_pairs_components + special_tokens_components
if __name__ == "__main__":
    # Soft emerald theme applied to the whole demo.
    theme = gr.themes.Soft(
        primary_hue=gr.themes.colors.emerald,
        secondary_hue=gr.themes.colors.emerald,
    )
    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            # Model selector: diffusion pipelines with 1-3 tokenizers.
            input_model_id = gr.Dropdown(
                label="Model ID",
                choices=[
                    "black-forest-labs/FLUX.1-dev",
                    "black-forest-labs/FLUX.1-schnell",
                    "stabilityai/stable-diffusion-3-medium-diffusers",
                    "stabilityai/stable-diffusion-xl-base-1.0",
                    "stable-diffusion-v1-5/stable-diffusion-v1-5",
                    "stabilityai/japanese-stable-diffusion-xl",
                    "rinna/japanese-stable-diffusion",
                ],
                value="black-forest-labs/FLUX.1-dev",
            )
            # Free-form prompt to tokenize.
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text here",
            )
            # Tab 1: token/ID pairs, one component per tokenizer slot.
            with gr.Tab(label="Tokenization Outputs"):
                with gr.Column():
                    output_highlighted_text_1 = gr.HighlightedText()
                    output_highlighted_text_2 = gr.HighlightedText()
                    output_highlighted_text_3 = gr.HighlightedText()
            # Tab 2: each tokenizer's special-token map.
            with gr.Tab(label="Special Tokens"):
                with gr.Column():
                    output_special_tokens_1 = gr.HighlightedText()
                    output_special_tokens_2 = gr.HighlightedText()
                    output_special_tokens_3 = gr.HighlightedText()
            with gr.Row():
                clear_button = gr.ClearButton(components=[input_text])
                submit_button = gr.Button("Run", variant="primary")

        # Wiring: `inference` returns exactly these six components, in
        # this order (three text-pair views, then three special-token views).
        all_inputs = [input_model_id, input_text]
        all_output = [
            output_highlighted_text_1,
            output_highlighted_text_2,
            output_highlighted_text_3,
            output_special_tokens_1,
            output_special_tokens_2,
            output_special_tokens_3,
        ]
        submit_button.click(fn=inference, inputs=all_inputs, outputs=all_output)
        # Pre-cached examples (cache_examples=True runs inference at build time).
        examples = gr.Examples(
            fn=inference,
            inputs=all_inputs,
            outputs=all_output,
            examples=[
                ["black-forest-labs/FLUX.1-dev", "a photo of cat"],
                [
                    "stabilityai/stable-diffusion-3-medium-diffusers",
                    'cat holding sign saying "I am a cat"',
                ],
                ["rinna/japanese-stable-diffusion", "空を飛んでいるネコの写真 油絵"],
            ],
            cache_examples=True,
        )

    # queue() enables request queuing before launching the server.
    demo.queue().launch()