import gradio as gr
import torch
from diffusers import DiffusionPipeline
from transformers import AutoTokenizer, CLIPTokenizerFast, T5TokenizerFast


def load_tokenizers(model_id: str) -> list[CLIPTokenizerFast | T5TokenizerFast | None]:
    """Load every tokenizer declared in a diffusers pipeline config.

    Diffusers pipelines keep their tokenizers in subfolders named
    ``tokenizer``, ``tokenizer_2`` and ``tokenizer_3``. The returned list is
    always length 3, padded with ``None``, so the UI can bind a fixed number
    of output components.

    Args:
        model_id: Hub model id of a diffusers pipeline.

    Raises:
        gr.Error: if the config declares fewer than 1 or more than 3 tokenizers.
    """
    config = DiffusionPipeline.load_config(model_id)
    # Count config entries that reference a tokenizer (e.g. 'tokenizer', 'tokenizer_2').
    num_tokenizers = sum('tokenizer' in key for key in config.keys())
    if not 1 <= num_tokenizers <= 3:
        raise gr.Error(f'Invalid number of tokenizers: {num_tokenizers}')

    # Subfolder naming convention: 'tokenizer', then 'tokenizer_2', 'tokenizer_3'.
    tokenizers = [
        AutoTokenizer.from_pretrained(
            model_id,
            subfolder='tokenizer' if i == 0 else f'tokenizer_{i + 1}',
        )
        for i in range(num_tokenizers)
    ]
    # Pad the list with None if there are fewer than 3 tokenizers
    tokenizers.extend([None] * (3 - num_tokenizers))
    return tokenizers


@torch.inference_mode()
def inference(model_id: str, input_text: str):
    """Tokenize *input_text* with each tokenizer of *model_id*.

    Returns six ``gr.HighlightedText`` updates: three (token, id) views
    followed by three special-token views. Slots without a tokenizer are
    returned hidden so the fixed UI layout stays stable.
    """
    tokenizers = load_tokenizers(model_id)

    text_pairs_components: list[gr.HighlightedText] = []
    special_tokens_components: list[gr.HighlightedText] = []
    for i, tokenizer in enumerate(tokenizers):
        if tokenizer is None:
            # Empty slot: keep both components hidden.
            text_pairs_components.append(gr.HighlightedText(visible=False))
            special_tokens_components.append(gr.HighlightedText(visible=False))
            continue

        label_text = f'Tokenizer {i + 1}: {tokenizer.__class__.__name__}'

        # Build (decoded token, token id) pairs for display.
        input_ids = tokenizer(
            text=input_text,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
        ).input_ids
        decoded_tokens = [tokenizer.decode(id_) for id_ in input_ids]
        token_pairs = [(str(token), str(id_)) for token, id_ in zip(decoded_tokens, input_ids)]
        text_pairs_components.append(
            gr.HighlightedText(
                label=label_text,
                value=token_pairs,
                visible=True,
                show_legend=True,
            )
        )

        # Collect the special tokens (BOS/EOS/PAD/...), skipping the
        # 'additional_special_tokens' list entry.
        special_tokens = [
            (str(k), str(v))
            for k, v in tokenizer.special_tokens_map.items()
            if k != 'additional_special_tokens'
        ]
        special_tokens_components.append(
            gr.HighlightedText(
                label=label_text,
                value=special_tokens,
                visible=True,
                show_legend=True,
            )
        )

    return text_pairs_components + special_tokens_components


if __name__ == '__main__':
    theme = gr.themes.Soft(
        primary_hue=gr.themes.colors.emerald,
        secondary_hue=gr.themes.colors.emerald,
    )
    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            input_model_id = gr.Dropdown(
                label='Model ID',
                choices=[
                    'black-forest-labs/FLUX.1-dev',
                    'black-forest-labs/FLUX.1-schnell',
                    'stabilityai/stable-diffusion-3-medium-diffusers',
                    'stabilityai/stable-diffusion-xl-base-1.0',
                    'stable-diffusion-v1-5/stable-diffusion-v1-5',
                    'stabilityai/japanese-stable-diffusion-xl',
                    'rinna/japanese-stable-diffusion',
                ],
                value='black-forest-labs/FLUX.1-dev',
            )
            input_text = gr.Textbox(
                label='Input Text',
                placeholder='Enter text here',
            )
            with gr.Tab(label='Tokenization Outputs'):
                with gr.Column():
                    output_highlighted_text_1 = gr.HighlightedText()
                    output_highlighted_text_2 = gr.HighlightedText()
                    output_highlighted_text_3 = gr.HighlightedText()
            with gr.Tab(label='Special Tokens'):
                with gr.Column():
                    output_special_tokens_1 = gr.HighlightedText()
                    output_special_tokens_2 = gr.HighlightedText()
                    output_special_tokens_3 = gr.HighlightedText()
            with gr.Row():
                clear_button = gr.ClearButton(components=[input_text])
                submit_button = gr.Button('Run', variant='primary')

        all_inputs = [input_model_id, input_text]
        all_output = [
            output_highlighted_text_1,
            output_highlighted_text_2,
            output_highlighted_text_3,
            output_special_tokens_1,
            output_special_tokens_2,
            output_special_tokens_3,
        ]
        submit_button.click(
            fn=inference,
            inputs=all_inputs,
            outputs=all_output,
        )
        examples = gr.Examples(
            fn=inference,
            inputs=all_inputs,
            outputs=all_output,
            examples=[
                ['black-forest-labs/FLUX.1-dev', 'a photo of cat'],
                ['stabilityai/stable-diffusion-3-medium-diffusers', 'cat holding sign saying "I am a cat"'],
                ['rinna/japanese-stable-diffusion', '空を飛んでいるネコの写真 油絵'],
            ],
            cache_examples=True,
        )

    demo.queue().launch()