File size: 5,425 Bytes
93013c6
e603ef9
 
 
 
 
 
 
d9d3f4b
93013c6
e603ef9
d9d3f4b
fc973c2
e603ef9
d9d3f4b
 
 
e603ef9
 
e5a75a4
e603ef9
 
fc973c2
e603ef9
93013c6
6c94b18
d9d3f4b
e603ef9
 
6c94b18
e603ef9
 
 
 
d9d3f4b
6c94b18
e603ef9
 
 
 
 
d9d3f4b
e603ef9
 
d9d3f4b
 
 
e603ef9
 
 
 
 
 
6c94b18
e603ef9
 
 
d9d3f4b
e603ef9
 
 
 
 
 
 
 
 
 
 
 
fc973c2
e603ef9
 
fc973c2
e603ef9
6c94b18
 
d9d3f4b
e603ef9
 
 
 
 
 
 
d9d3f4b
6c94b18
d9d3f4b
 
 
 
 
 
 
6c94b18
d9d3f4b
e603ef9
 
d9d3f4b
 
e603ef9
6c94b18
d9d3f4b
e603ef9
 
 
 
d9d3f4b
e603ef9
 
 
 
 
 
 
d9d3f4b
e603ef9
 
 
 
 
 
 
 
 
 
d9d3f4b
e603ef9
 
 
 
 
 
d9d3f4b
 
 
 
 
 
e603ef9
d9d3f4b
e603ef9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, CLIPTokenizerFast, T5TokenizerFast


def _is_registered_component(entry) -> bool:
    """Return True if a pipeline-config entry names an actual component.

    Entries in a diffusers pipeline config are ``[library, class_name]`` pairs.
    A component slot that the pipeline class supports but this checkpoint does
    not ship is recorded as ``[null, null]``, which must not count as present.
    """
    if not isinstance(entry, (list, tuple)):
        return False
    return any(part is not None for part in entry)


def load_tokenizers(model_id: str) -> list[CLIPTokenizerFast | T5TokenizerFast | None]:
    """Load the tokenizers of a diffusers pipeline.

    The pipeline config of ``model_id`` is inspected for the three tokenizer
    slots ``tokenizer``, ``tokenizer_2`` and ``tokenizer_3``. Each registered
    slot is loaded from the subfolder of the same name.

    Args:
        model_id: Hub repository id (or local path) of a diffusers pipeline.

    Returns:
        A list of exactly three entries, one per slot in the order above. An
        entry is ``None`` when the pipeline does not provide that tokenizer, so
        a tokenizer always stays at its own slot index even if an earlier slot
        is missing.

    Raises:
        gr.Error: if the pipeline registers no tokenizer at all, or registers a
            tokenizer component outside the three supported slots.
    """
    slots = ("tokenizer", "tokenizer_2", "tokenizer_3")
    config = DiffusionPipeline.load_config(model_id)

    # Only keys whose value names a real component count; "[null, null]"
    # placeholders would otherwise point at a subfolder that does not exist.
    registered = {
        key
        for key, entry in config.items()
        if key.startswith("tokenizer") and _is_registered_component(entry)
    }

    unsupported = registered.difference(slots)
    if unsupported:
        raise gr.Error(f"Unsupported tokenizer components: {sorted(unsupported)}")
    if not registered:
        raise gr.Error(f"Invalid number of tokenizers: {len(registered)}")

    # Load by slot name instead of by count so gaps (e.g. tokenizer + tokenizer_3)
    # still resolve to the right subfolder and the right output position.
    return [
        AutoTokenizer.from_pretrained(model_id, subfolder=slot) if slot in registered else None
        for slot in slots
    ]


def _token_id_pairs(tokenizer, text: str) -> list[tuple[str, str]]:
    """Tokenize *text* (with truncation enabled) and pair each token with its id.

    Returns a list of ``(decoded_token, token_id_as_str)`` tuples in the format
    ``gr.HighlightedText`` expects (text, label).
    """
    input_ids = tokenizer(
        text=text,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
    ).input_ids
    # Decode ids one at a time so each highlighted span shows a single token.
    return [(tokenizer.decode(token_id), str(token_id)) for token_id in input_ids]


def _special_token_pairs(tokenizer) -> list[tuple[str, str]]:
    """Return ``(role, token)`` pairs from the tokenizer's ``special_tokens_map``.

    The ``additional_special_tokens`` entry is skipped.
    """
    return [
        (str(role), str(token))
        for role, token in tokenizer.special_tokens_map.items()
        if role != "additional_special_tokens"
    ]


@torch.no_grad()
def inference(model_id: str, input_text: str):
    """Build the Gradio output components for every tokenizer of ``model_id``.

    Args:
        model_id: Hub repository id of a diffusers pipeline.
        input_text: Prompt to tokenize.

    Returns:
        A flat list of six ``gr.HighlightedText`` components: first the three
        tokenization outputs, then the three special-token outputs, in
        tokenizer-slot order. Slots without a tokenizer are hidden.
    """
    tokenizers = load_tokenizers(model_id)

    text_pairs_components = []
    special_tokens_components = []
    for i, tokenizer in enumerate(tokenizers):
        # Compare against None explicitly: a tokenizer's truthiness would come
        # from its __len__ (vocabulary size), not from whether it exists.
        if tokenizer is None:
            text_pairs_components.append(gr.HighlightedText(visible=False))
            special_tokens_components.append(gr.HighlightedText(visible=False))
            continue

        label_text = f"Tokenizer {i + 1}: {tokenizer.__class__.__name__}"

        # Pairs of (decoded token, token id) for the input text.
        text_pairs_components.append(
            gr.HighlightedText(
                label=label_text,
                value=_token_id_pairs(tokenizer, input_text),
                visible=True,
                show_legend=True,
            )
        )

        # The tokenizer's special tokens, keyed by role.
        special_tokens_components.append(
            gr.HighlightedText(
                label=label_text,
                value=_special_token_pairs(tokenizer),
                visible=True,
                show_legend=True,
            )
        )

    return text_pairs_components + special_tokens_components


if __name__ == "__main__":
    # Build and launch the tokenizer-inspection UI.
    theme = gr.themes.Soft(
        primary_hue=gr.themes.colors.emerald,
        secondary_hue=gr.themes.colors.emerald,
    )
    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            input_model_id = gr.Dropdown(
                label="Model ID",
                choices=[
                    "black-forest-labs/FLUX.1-dev",
                    "black-forest-labs/FLUX.1-schnell",
                    "stabilityai/stable-diffusion-3-medium-diffusers",
                    "stabilityai/stable-diffusion-xl-base-1.0",
                    "stable-diffusion-v1-5/stable-diffusion-v1-5",
                    "stabilityai/japanese-stable-diffusion-xl",
                    "rinna/japanese-stable-diffusion",
                ],
                value="black-forest-labs/FLUX.1-dev",
            )
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text here",
            )

            # Three slots per tab, matching the three entries inference()
            # produces for each kind of output.
            with gr.Tab(label="Tokenization Outputs"):
                with gr.Column():
                    output_highlighted_text_1 = gr.HighlightedText()
                    output_highlighted_text_2 = gr.HighlightedText()
                    output_highlighted_text_3 = gr.HighlightedText()
            with gr.Tab(label="Special Tokens"):
                with gr.Column():
                    output_special_tokens_1 = gr.HighlightedText()
                    output_special_tokens_2 = gr.HighlightedText()
                    output_special_tokens_3 = gr.HighlightedText()

            all_inputs = [input_model_id, input_text]
            # Order must match inference()'s return: tokenizations, then special tokens.
            all_output = [
                output_highlighted_text_1,
                output_highlighted_text_2,
                output_highlighted_text_3,
                output_special_tokens_1,
                output_special_tokens_2,
                output_special_tokens_3,
            ]

            with gr.Row():
                # Clear the results together with the prompt so no stale
                # tokenization stays on screen.
                clear_button = gr.ClearButton(components=[input_text, *all_output])
                submit_button = gr.Button("Run", variant="primary")

        submit_button.click(fn=inference, inputs=all_inputs, outputs=all_output)
        # Pressing Enter in the textbox runs the same handler as the Run button.
        input_text.submit(fn=inference, inputs=all_inputs, outputs=all_output)

        gr.Examples(
            fn=inference,
            inputs=all_inputs,
            outputs=all_output,
            examples=[
                ["black-forest-labs/FLUX.1-dev", "a photo of cat"],
                [
                    "stabilityai/stable-diffusion-3-medium-diffusers",
                    'cat holding sign saying "I am a cat"',
                ],
                ["rinna/japanese-stable-diffusion", "空を飛んでいるネコの写真 油絵"],
            ],
            cache_examples=True,
        )

    demo.queue().launch()