import gradio as gr
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, CLIPTokenizerFast, T5TokenizerFast
import pandas as pd

def load_tokenizers(model_id: str) -> list[CLIPTokenizerFast | T5TokenizerFast | None]:
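    """Load every tokenizer declared in the pipeline config (model_index.json).

    The returned list is padded with None so it always has three entries,
    matching the three output slots in the UI.
    """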
    config = DiffusionPipeline.load_config(model_id)
    num_tokenizers = sum("tokenizer" in key for key in config.keys())

    if not 1 <= num_tokenizers <= 3:
        raise gr.Error(f"Invalid number of tokenizers: {num_tokenizers}")

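    # Diffusers names the subfolders "tokenizer", "tokenizer_2", "tokenizer_3"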
    tokenizers = [
        AutoTokenizer.from_pretrained(
            model_id, subfolder=f'tokenizer{"" if i == 0 else f"_{i + 1}"}'
        )
        for i in range(num_tokenizers)
    ]

    # Pad the list with None if there are fewer than 3 tokenizers
    tokenizers.extend([None] * (3 - num_tokenizers))

    return tokenizers


@torch.no_grad()
def inference(model_id: str, text: str):
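    """Tokenize the text with each of the model's tokenizers.

    Returns nine Gradio components: three HighlightedText blocks of
    (token, token ID) pairs, three HighlightedText blocks of special
    tokens, and three Dataframes of tokenizer details. Slots for
    tokenizers the model does not have are returned hidden.
    """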
    tokenizers = load_tokenizers(model_id)

    text_pairs_components = []
    special_tokens_components = []
    tokenizer_details_components = []
    for i, tokenizer in enumerate(tokenizers):
        if tokenizer:
            label_text = f"Tokenizer {i + 1}: {tokenizer.__class__.__name__}"

            # Build (token, token ID) pairs for the input text
            input_ids = tokenizer(
                text=text,
                truncation=False,
                return_length=False,
                return_overflowing_tokens=False,
            ).input_ids
            decoded_tokens = [tokenizer.decode(id_) for id_ in input_ids]
            token_pairs = [
                (str(token), str(id_)) for token, id_ in zip(decoded_tokens, input_ids)
            ]
            output_text_pair_component = gr.HighlightedText(
                label=label_text,
                value=token_pairs,
                visible=True,
            )

            # Collect the tokenizer's special tokens (excluding additional_special_tokens)
            special_tokens = []
            for k, v in tokenizer.special_tokens_map.items():
                if k == "additional_special_tokens":
                    continue
                special_token_map = (str(k), str(v))
                special_tokens.append(special_token_map)
            output_special_tokens_component = gr.HighlightedText(
                label=label_text,
                value=special_tokens,
                visible=True,
            )

            # Summarize the tokenizer's key attributes in a table
            tokenizer_details = pd.DataFrame([
                ("Type", tokenizer.__class__.__name__),
                ("Vocab Size", tokenizer.vocab_size),
                ("Model Max Length", tokenizer.model_max_length),
                ("Padding Side", tokenizer.padding_side),
                ("Truncation Side", tokenizer.truncation_side),
            ], columns=["Attribute", "Value"])
            output_tokenizer_details = gr.Dataframe(
                headers=["Attribute", "Value"],
                value=tokenizer_details,
                label=label_text,
                visible=True,
            )
        else:
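            # This model has no tokenizer in this slot, so keep the components hidden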
            output_text_pair_component = gr.HighlightedText(visible=False)
            output_special_tokens_component = gr.HighlightedText(visible=False)
            output_tokenizer_details = gr.Dataframe(visible=False)

        text_pairs_components.append(output_text_pair_component)
        special_tokens_components.append(output_special_tokens_component)
        tokenizer_details_components.append(output_tokenizer_details)

    return text_pairs_components + special_tokens_components + tokenizer_details_components


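# Gradio UI: a model selector, a text box, and three tabs of per-tokenizer outputs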
if __name__ == "__main__":
    theme = gr.themes.Soft(
        primary_hue=gr.themes.colors.emerald,
        secondary_hue=gr.themes.colors.emerald,
    )
    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            input_model_id = gr.Dropdown(
                label="Model ID",
                choices=[
                    "black-forest-labs/FLUX.1-dev",
                    "black-forest-labs/FLUX.1-schnell",
                    "stabilityai/stable-diffusion-3-medium-diffusers",
                    "stabilityai/stable-diffusion-xl-base-1.0",
                    "stable-diffusion-v1-5/stable-diffusion-v1-5",
                    "stabilityai/japanese-stable-diffusion-xl",
                    "rinna/japanese-stable-diffusion",
                ],
                value="black-forest-labs/FLUX.1-dev",
            )
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text here",
            )

            with gr.Tab(label="Tokenization Outputs"):
                with gr.Column():
                    output_highlighted_text_1 = gr.HighlightedText()
                    output_highlighted_text_2 = gr.HighlightedText()
                    output_highlighted_text_3 = gr.HighlightedText()
            with gr.Tab(label="Special Tokens"):
                with gr.Column():
                    output_special_tokens_1 = gr.HighlightedText()
                    output_special_tokens_2 = gr.HighlightedText()
                    output_special_tokens_3 = gr.HighlightedText()
            with gr.Tab(label="Tokenizer Details"):
                with gr.Column():
                    output_tokenizer_details_1 = gr.Dataframe(headers=["Attribute", "Value"])
                    output_tokenizer_details_2 = gr.Dataframe(headers=["Attribute", "Value"])
                    output_tokenizer_details_3 = gr.Dataframe(headers=["Attribute", "Value"])

            with gr.Row():
                clear_button = gr.ClearButton(components=[input_text])
                submit_button = gr.Button("Run", variant="primary")

        all_inputs = [input_model_id, input_text]
        all_output = [
            output_highlighted_text_1,
            output_highlighted_text_2,
            output_highlighted_text_3,
            output_special_tokens_1,
            output_special_tokens_2,
            output_special_tokens_3,
            output_tokenizer_details_1,
            output_tokenizer_details_2,
            output_tokenizer_details_3,
        ]
        submit_button.click(fn=inference, inputs=all_inputs, outputs=all_output)

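        # Pre-compute and cache outputs for the example prompts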
        examples = gr.Examples(
            fn=inference,
            inputs=all_inputs,
            outputs=all_output,
            examples=[
                ["black-forest-labs/FLUX.1-dev", "a photo of cat"],
                [
                    "stabilityai/stable-diffusion-3-medium-diffusers",
                    'cat holding sign saying "I am a cat"',
                ],
                ["rinna/japanese-stable-diffusion", "空を飛んでいるネコの写真 油絵"],
            ],
            cache_examples=True,
        )

    demo.queue().launch()