import random

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM

from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images

##
# Code from deepseek-ai/Janus
# Space from huggingface/twodgirl.

# Side length, in pixels, of the image area represented by one generated token
# (the app snaps the requested width/height to a multiple of this).
PATCH_SIZE = 16


def _model_device():
    """Return the device that holds the language model's input embeddings."""
    return model.language_model.get_input_embeddings().weight.device


def generate(input_ids, width, height, temperature: float = 1, parallel_size: int = 1,
             cfg_weight: float = 5, image_token_num_per_image: int = 576,
             patch_size: int = 16, generator=None):
    """Autoregressively sample image tokens with classifier-free guidance and decode them.

    Args:
        input_ids: 1-D tensor of prompt token ids (chat template + image start tag).
        width, height: Output size in pixels; each must be a multiple of ``patch_size``.
        temperature: Softmax temperature applied to the guided logits.
        parallel_size: Number of images sampled in one batch.
        cfg_weight: Classifier-free guidance weight.
        image_token_num_per_image: Number of tokens to sample per image. It must equal
            ``(height // patch_size) * (width // patch_size)`` so the tokens can be
            decoded back onto the image grid.
        patch_size: Pixel size of one token in the output image.
        generator: Optional ``torch.Generator`` used for sampling (for reproducible
            seeds). It must live on the same device as the model.

    Returns:
        ``(generated_tokens, patches)``: the sampled token ids, shaped
        ``(parallel_size, image_token_num_per_image)`` with dtype int32, and the output
        of ``model.gen_vision_model.decode_code``. The latter is presumably pixels
        shaped (parallel_size, 3, height, width) in [-1, 1] (see ``unpack``) — confirm.

    Raises:
        ValueError: if the token count does not match the requested image grid.
    """
    grid_h, grid_w = height // patch_size, width // patch_size
    if grid_h * grid_w != image_token_num_per_image:
        # Fail before the (slow) sampling loop rather than inside decode_code.
        raise ValueError(
            f'image_token_num_per_image={image_token_num_per_image} does not match a '
            f'{grid_h}x{grid_w} token grid for a {width}x{height} image')

    embed = model.language_model.get_input_embeddings()
    device = embed.weight.device

    # Even rows hold the conditional prompt; odd rows hold the unconditional prompt,
    # where everything between the first and last token is replaced by padding.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int, device=device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = processor.pad_id
    inputs_embeds = embed(tokens)

    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image),
                                   dtype=torch.int, device=device)
    pkv = None
    for i in range(image_token_num_per_image):
        outputs = model.language_model.model(inputs_embeds=inputs_embeds, use_cache=True,
                                             past_key_values=pkv)
        pkv = outputs.past_key_values
        hidden_states = outputs.last_hidden_state
        logits = model.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        # Classifier-free guidance: move away from the unconditional prediction.
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1, generator=generator)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)
        # Feed the sampled token to both the conditional and unconditional rows.
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)],
                               dim=1).view(-1)
        img_embeds = model.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)

    # decode_code expects (batch, channels, height, width) — height comes first.
    patches = model.gen_vision_model.decode_code(generated_tokens,
                                                 shape=[parallel_size, 8, grid_h, grid_w])
    return generated_tokens, patches


def unpack(dec, width, height, parallel_size=1):
    """Convert decoded pixels (assumed in [-1, 1], channels-first) to uint8 HWC images.

    Returns:
        A numpy array of shape ``(parallel_size, height, width, 3)`` with dtype uint8.
    """
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
    # Rows are image height and columns are image width (PIL's array convention).
    visual_img = np.zeros((parallel_size, height, width, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    return visual_img


@torch.inference_mode()
def generate_image(prompt,
                   width,
                   height,
                   # num_steps,
                   guidance,
                   seed):
    """Gradio callback: generate one image for ``prompt``.

    Width and height are snapped down to multiples of ``PATCH_SIZE``. A negative
    (or empty) seed draws a fresh random seed.

    Returns:
        ``(PIL image, seed actually used, None)``. The ``None`` clears the hidden
        gallery output.
    """
    width = int(width) // PATCH_SIZE * PATCH_SIZE
    height = int(height) // PATCH_SIZE * PATCH_SIZE

    seed = -1 if seed is None else int(seed)
    if seed < 0:
        seed = random.randint(0, 2**31 - 1)
    # Seed a dedicated generator so results are reproducible without touching global RNG.
    generator = torch.Generator(device=_model_device()).manual_seed(seed)

    messages = [{'role': 'User', 'content': prompt},
                {'role': 'Assistant', 'content': ''}]
    text = processor.apply_sft_template_for_multi_turn_prompts(
        conversations=messages, sft_format=processor.sft_format, system_prompt='')
    text = text + processor.image_start_tag
    # Tokenize the full templated text, not the raw prompt: the chat template and the
    # image start tag are what put the model into image-generation mode.
    input_ids = torch.LongTensor(processor.tokenizer.encode(text))

    _, patches = generate(
        input_ids, width, height, cfg_weight=guidance,
        image_token_num_per_image=(width // PATCH_SIZE) * (height // PATCH_SIZE),
        patch_size=PATCH_SIZE, generator=generator)
    images = unpack(patches, width, height)
    return Image.fromarray(images[0]), seed, None


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label='Prompt', value='portrait, color, cinematic')
            width = gr.Slider(256, 1536, 256, step=16, label='Width')
            height = gr.Slider(256, 1536, 256, step=16, label='Height')
            guidance = gr.Slider(1.0, 10.0, 5, step=0.1, label='Guidance')
            seed = gr.Number(-1, precision=0, label='Seed (-1 for random)')
            generate_btn = gr.Button('Generate')
        with gr.Column():
            output_image = gr.Image(label='Generated Image')
            seed_output = gr.Textbox(label='Used Seed')
            intermediate_output = gr.Gallery(label='Output', elem_id='gallery', visible=False)
    prompt.submit(
        fn=generate_image,
        inputs=[prompt, width, height, guidance, seed],
        outputs=[output_image, seed_output, intermediate_output],
    )
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, width, height, guidance, seed],
        outputs=[output_image, seed_output, intermediate_output],
    )

if __name__ == '__main__':
    model_path = 'deepseek-ai/Janus-1.3B'
    processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
    tokenizer = processor.tokenizer
    config = AutoConfig.from_pretrained(model_path)
    # Use the eager attention implementation for the language model.
    config.language_config._attn_implementation = 'eager'
    # from_config() would only build the architecture with randomly initialised
    # weights; from_pretrained() with the patched config loads the checkpoint.
    model: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, trust_remote_code=True)
    # On a GPU: model = model.to(torch.bfloat16).cuda()
    # NOTE(review): float16 on CPU may be slow or unsupported for some ops
    # depending on the torch version — confirm on the target hardware.
    model = model.to(torch.float16)
    demo.launch()