# tokenvisor-sd / app.py
# Author: Prgckwb
# Commit: d9d3f4b (":tada: init")
# (Hugging Face page residue: raw · history · blame · No virus · 5.43 kB)
import functools

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, CLIPTokenizerFast, T5TokenizerFast
# Number of tokenizer slots returned by load_tokenizers (missing ones are None).
_MAX_TOKENIZERS = 3


def _tokenizer_slot(key: str) -> int | None:
    """Map a pipeline component key to its 0-based tokenizer slot.

    ``"tokenizer"`` maps to slot 0 and ``"tokenizer_N"`` (N >= 2) maps to slot
    ``N - 1``. Any other key (e.g. ``"text_encoder"``) returns ``None``.
    """
    if key == "tokenizer":
        return 0
    suffix = key.removeprefix("tokenizer_")
    if suffix != key and suffix.isdecimal() and int(suffix) >= 2:
        return int(suffix) - 1
    return None


def _is_declared(value: object) -> bool:
    """Return False for config entries that declare no component, e.g. ``[null, null]``."""
    if value is None:
        return False
    if isinstance(value, (list, tuple)):
        return any(part is not None for part in value)
    return True


@functools.lru_cache(maxsize=8)
def _load_tokenizer_slots(
    model_id: str,
) -> tuple[CLIPTokenizerFast | T5TokenizerFast | None, ...]:
    """Load the tokenizers of *model_id* into fixed slots, caching the result.

    Raises:
        gr.Error: if the pipeline declares no tokenizer, more than
            ``_MAX_TOKENIZERS`` tokenizers, or a tokenizer index beyond the
            available slots. Exceptions are not cached by ``lru_cache``.
    """
    config = DiffusionPipeline.load_config(model_id)

    # slot index -> subfolder name, e.g. {0: "tokenizer", 1: "tokenizer_2"}.
    declared: dict[int, str] = {}
    for key, value in config.items():
        slot = _tokenizer_slot(key)
        if slot is not None and _is_declared(value):
            declared[slot] = key

    if not 1 <= len(declared) <= _MAX_TOKENIZERS:
        raise gr.Error(f"Invalid number of tokenizers: {len(declared)}")
    if max(declared) >= _MAX_TOKENIZERS:
        raise gr.Error(
            f"Unsupported tokenizer components: {sorted(declared.values())}"
        )

    # Load each tokenizer from the subfolder named by its own config key, so a
    # sparse layout (e.g. tokenizer + tokenizer_3) still lands in the right slot.
    return tuple(
        AutoTokenizer.from_pretrained(model_id, subfolder=declared[slot])
        if slot in declared
        else None
        for slot in range(_MAX_TOKENIZERS)
    )


def load_tokenizers(model_id: str) -> list[CLIPTokenizerFast | T5TokenizerFast | None]:
    """Return the tokenizers of a diffusers pipeline, padded with None to three slots.

    Slot ``i`` holds the tokenizer from subfolder ``tokenizer`` (i == 0) or
    ``tokenizer_{i + 1}``; slots the pipeline does not declare are ``None``.
    Loaded tokenizers are cached per *model_id*, so repeated requests reuse the
    same instances instead of calling ``from_pretrained`` again.

    Args:
        model_id: Hugging Face Hub repo id (or local path) of the pipeline.

    Returns:
        A new list of length three.

    Raises:
        gr.Error: if the number or layout of tokenizers is unsupported.
    """
    # Return a fresh list so callers may mutate it without touching the cache.
    return list(_load_tokenizer_slots(model_id))
def _token_id_pairs(tokenizer, text: str) -> list[tuple[str, str]]:
    """Tokenize *text* and pair each individually decoded token with its id, as strings."""
    input_ids = tokenizer(
        text=text,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
    ).input_ids
    return [(str(tokenizer.decode(token_id)), str(token_id)) for token_id in input_ids]


def _special_token_pairs(tokenizer) -> list[tuple[str, str]]:
    """Return ``(name, token)`` string pairs from ``special_tokens_map``.

    The ``additional_special_tokens`` entry is skipped.
    """
    return [
        (str(name), str(token))
        for name, token in tokenizer.special_tokens_map.items()
        if name != "additional_special_tokens"
    ]


@torch.no_grad()
def inference(model_id: str, input_text: str):
    """Tokenize *input_text* with every tokenizer of *model_id* for display.

    Args:
        model_id: Pipeline repo id passed to ``load_tokenizers``.
        input_text: Text to tokenize.

    Returns:
        Six ``gr.HighlightedText`` updates: the three token/id views followed by
        the three special-token views. Slots without a tokenizer are hidden.
    """
    tokenizers = load_tokenizers(model_id)

    text_pairs_components = []
    special_tokens_components = []
    for i, tokenizer in enumerate(tokenizers):
        # Compare against None explicitly: tokenizer truthiness would go
        # through its __len__ (vocabulary size).
        if tokenizer is None:
            text_pairs_components.append(gr.HighlightedText(visible=False))
            special_tokens_components.append(gr.HighlightedText(visible=False))
            continue

        label_text = f"Tokenizer {i + 1}: {tokenizer.__class__.__name__}"

        # Pair each decoded token with its token id.
        text_pairs_components.append(
            gr.HighlightedText(
                label=label_text,
                value=_token_id_pairs(tokenizer, input_text),
                visible=True,
                show_legend=True,
            )
        )
        # List the tokenizer's special tokens.
        special_tokens_components.append(
            gr.HighlightedText(
                label=label_text,
                value=_special_token_pairs(tokenizer),
                visible=True,
                show_legend=True,
            )
        )

    return text_pairs_components + special_tokens_components
if __name__ == "__main__":
    theme = gr.themes.Soft(
        primary_hue=gr.themes.colors.emerald,
        secondary_hue=gr.themes.colors.emerald,
    )

    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            input_model_id = gr.Dropdown(
                label="Model ID",
                choices=[
                    "black-forest-labs/FLUX.1-dev",
                    "black-forest-labs/FLUX.1-schnell",
                    "stabilityai/stable-diffusion-3-medium-diffusers",
                    "stabilityai/stable-diffusion-xl-base-1.0",
                    "stable-diffusion-v1-5/stable-diffusion-v1-5",
                    "stabilityai/japanese-stable-diffusion-xl",
                    "rinna/japanese-stable-diffusion",
                ],
                value="black-forest-labs/FLUX.1-dev",
            )
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text here",
            )

        with gr.Tab(label="Tokenization Outputs"):
            with gr.Column():
                output_highlighted_text_1 = gr.HighlightedText()
                output_highlighted_text_2 = gr.HighlightedText()
                output_highlighted_text_3 = gr.HighlightedText()
        with gr.Tab(label="Special Tokens"):
            with gr.Column():
                output_special_tokens_1 = gr.HighlightedText()
                output_special_tokens_2 = gr.HighlightedText()
                output_special_tokens_3 = gr.HighlightedText()

        all_inputs = [input_model_id, input_text]
        # Order must match inference(): token views first, then special tokens.
        all_outputs = [
            output_highlighted_text_1,
            output_highlighted_text_2,
            output_highlighted_text_3,
            output_special_tokens_1,
            output_special_tokens_2,
            output_special_tokens_3,
        ]

        with gr.Row():
            # Clear the results as well as the prompt so stale tokenizations are
            # not left next to an empty input box.
            clear_button = gr.ClearButton(components=[input_text, *all_outputs])
            submit_button = gr.Button("Run", variant="primary")

        submit_button.click(fn=inference, inputs=all_inputs, outputs=all_outputs)

        gr.Examples(
            fn=inference,
            inputs=all_inputs,
            outputs=all_outputs,
            examples=[
                ["black-forest-labs/FLUX.1-dev", "a photo of cat"],
                [
                    "stabilityai/stable-diffusion-3-medium-diffusers",
                    'cat holding sign saying "I am a cat"',
                ],
                ["rinna/japanese-stable-diffusion", "空を飛んでいるネコの写真 油絵"],
            ],
            cache_examples=True,
        )

    demo.queue().launch()