tokenvisor-sd / app.py
Prgckwb
:tada: init
0a485e6
raw
history blame contribute delete
No virus
6.73 kB
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, CLIPTokenizerFast, T5TokenizerFast
import pandas as pd
def _tokenizer_position(key: str) -> int | None:
    """Return the 1-based slot of a ``model_index.json`` key naming a tokenizer.

    ``"tokenizer"`` maps to 1 and ``"tokenizer_<n>"`` maps to ``n``. Any other
    key, including keys that merely contain the word "tokenizer", yields None.
    """
    if key == "tokenizer":
        return 1
    prefix, sep, suffix = key.partition("_")
    if prefix == "tokenizer" and sep and suffix.isdecimal():
        return int(suffix)
    return None


def _is_declared_component(entry: object) -> bool:
    """Return False for a ``model_index.json`` entry that records no component.

    diffusers stores a component that was saved as ``None`` as ``[null, null]``,
    and such an entry has no subfolder to load from.
    """
    if isinstance(entry, (list, tuple)):
        return any(part is not None for part in entry)
    return entry is not None


def load_tokenizers(model_id: str) -> list[CLIPTokenizerFast | T5TokenizerFast | None]:
    """Load the tokenizers declared by a diffusers pipeline repository.

    Args:
        model_id: Hugging Face Hub ID of a diffusers pipeline.

    Returns:
        Exactly three entries: the pipeline's tokenizers ordered as
        ``tokenizer``, ``tokenizer_2``, ``tokenizer_3``, followed by ``None``
        padding for every missing slot.

    Raises:
        gr.Error: If the pipeline declares fewer than one or more than three
            tokenizers.
    """
    config = DiffusionPipeline.load_config(model_id)

    # Use the key names that are actually present as subfolders, rather than
    # rebuilding "tokenizer", "tokenizer_2", ... from a count. Null entries and
    # unrelated keys therefore never send us to a subfolder that does not exist.
    positioned = sorted(
        (position, key)
        for key, entry in config.items()
        if (position := _tokenizer_position(key)) is not None
        and _is_declared_component(entry)
    )
    subfolders = [key for _, key in positioned]

    num_tokenizers = len(subfolders)
    if not 1 <= num_tokenizers <= 3:
        raise gr.Error(f"Invalid number of tokenizers: {num_tokenizers}")

    tokenizers = [
        AutoTokenizer.from_pretrained(model_id, subfolder=subfolder)
        for subfolder in subfolders
    ]
    # Pad to a fixed length of 3 so the UI always receives one entry per slot.
    tokenizers.extend([None] * (3 - num_tokenizers))
    return tokenizers
def _token_id_pairs(tokenizer, text: str) -> list[tuple[str, str]]:
    """Tokenize *text* and pair each token's decoded string with its ID.

    Both elements are strings so the result can be passed directly as the
    value of a ``gr.HighlightedText``.
    """
    input_ids = tokenizer(
        text=text,
        truncation=False,  # show every token, even past model_max_length
        return_length=False,
        return_overflowing_tokens=False,
    ).input_ids
    # Decode one ID at a time so each entry shows the surface form of a single token.
    return [(str(tokenizer.decode(token_id)), str(token_id)) for token_id in input_ids]


def _special_token_pairs(tokenizer) -> list[tuple[str, str]]:
    """Return ``(role, token)`` string pairs from ``tokenizer.special_tokens_map``.

    The ``additional_special_tokens`` entry is skipped. In transformers its
    value is a list of tokens rather than a single token.
    """
    return [
        (str(role), str(token))
        for role, token in tokenizer.special_tokens_map.items()
        if role != "additional_special_tokens"
    ]


def _tokenizer_details(tokenizer) -> pd.DataFrame:
    """Summarize a tokenizer's main settings as an Attribute/Value table."""
    return pd.DataFrame(
        [
            ("Type", tokenizer.__class__.__name__),
            ("Vocab Size", tokenizer.vocab_size),
            ("Model Max Length", tokenizer.model_max_length),
            ("Padding Side", tokenizer.padding_side),
            ("Truncation Side", tokenizer.truncation_side),
        ],
        columns=["Attribute", "Value"],
    )


# Nothing below runs torch ops, so this decorator is effectively a no-op; kept as-is.
@torch.no_grad()
def inference(model_id: str, text: str):
    """Tokenize *text* with every tokenizer of *model_id* and build UI updates.

    Args:
        model_id: Hub ID of a diffusers pipeline (passed to ``load_tokenizers``).
        text: Prompt to tokenize.

    Returns:
        A flat list of component updates, grouped as: one token/ID
        ``HighlightedText`` per tokenizer slot, then one special-token
        ``HighlightedText`` per slot, then one details ``Dataframe`` per slot.
        Slots whose tokenizer is ``None`` get hidden components.

    Raises:
        gr.Error: Propagated from ``load_tokenizers`` for unsupported pipelines.
    """
    tokenizers = load_tokenizers(model_id)

    text_pairs_components = []
    special_tokens_components = []
    tokenizer_details_components = []

    for i, tokenizer in enumerate(tokenizers):
        # Compare against None explicitly: truthiness would call the
        # tokenizer's __len__, which can be expensive for slow tokenizers.
        if tokenizer is None:
            # This pipeline has no tokenizer for this slot, so hide its outputs.
            text_pairs_components.append(gr.HighlightedText(visible=False))
            special_tokens_components.append(gr.HighlightedText(visible=False))
            tokenizer_details_components.append(gr.Dataframe(visible=False))
            continue

        label_text = f"Tokenizer {i + 1}: {tokenizer.__class__.__name__}"

        text_pairs_components.append(
            gr.HighlightedText(
                label=label_text,
                value=_token_id_pairs(tokenizer, text),
                visible=True,
            )
        )
        special_tokens_components.append(
            gr.HighlightedText(
                label=label_text,
                value=_special_token_pairs(tokenizer),
                visible=True,
            )
        )
        tokenizer_details_components.append(
            gr.Dataframe(
                headers=["Attribute", "Value"],
                value=_tokenizer_details(tokenizer),
                label=label_text,
                visible=True,
            )
        )

    return text_pairs_components + special_tokens_components + tokenizer_details_components
if __name__ == "__main__":
    # Pipelines offered in the dropdown; the first entry is preselected.
    model_choices = [
        "black-forest-labs/FLUX.1-dev",
        "black-forest-labs/FLUX.1-schnell",
        "stabilityai/stable-diffusion-3-medium-diffusers",
        "stabilityai/stable-diffusion-xl-base-1.0",
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        "stabilityai/japanese-stable-diffusion-xl",
        "rinna/japanese-stable-diffusion",
    ]
    # (model_id, prompt) rows shown under the interface and cached at startup.
    example_rows = [
        ["black-forest-labs/FLUX.1-dev", "a photo of cat"],
        [
            "stabilityai/stable-diffusion-3-medium-diffusers",
            'cat holding sign saying "I am a cat"',
        ],
        ["rinna/japanese-stable-diffusion", "空を飛んでいるネコの写真 油絵"],
    ]
    # Three output slots per tab, matching the three entries load_tokenizers returns.
    slot_count = 3

    emerald = gr.themes.colors.emerald
    theme = gr.themes.Soft(primary_hue=emerald, secondary_hue=emerald)

    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            input_model_id = gr.Dropdown(
                label="Model ID",
                choices=model_choices,
                value="black-forest-labs/FLUX.1-dev",
            )
            input_text = gr.Textbox(label="Input Text", placeholder="Enter text here")

        with gr.Tab(label="Tokenization Outputs"):
            with gr.Column():
                token_pair_outputs = [gr.HighlightedText() for _ in range(slot_count)]
        with gr.Tab(label="Special Tokens"):
            with gr.Column():
                special_token_outputs = [gr.HighlightedText() for _ in range(slot_count)]
        with gr.Tab(label="Tokenizer Details"):
            with gr.Column():
                detail_outputs = [
                    gr.Dataframe(headers=["Attribute", "Value"])
                    for _ in range(slot_count)
                ]

        with gr.Row():
            clear_button = gr.ClearButton(components=[input_text])
            submit_button = gr.Button("Run", variant="primary")

        all_inputs = [input_model_id, input_text]
        # Order must match the flat list returned by inference().
        all_output = token_pair_outputs + special_token_outputs + detail_outputs

        submit_button.click(fn=inference, inputs=all_inputs, outputs=all_output)
        examples = gr.Examples(
            fn=inference,
            inputs=all_inputs,
            outputs=all_output,
            examples=example_rows,
            cache_examples=True,
        )

    demo.queue().launch()