"""Image-conditioned image generation with the Janus multimodal model, using
classifier-free guidance over interleaved conditional/unconditional prompts."""

import os
import PIL.Image
import torch
import numpy as np
import spaces  # needed for the @spaces.GPU decorator below (Hugging Face ZeroGPU)
from janus.utils.io import load_pil_images
from janus.models import MultiModalityCausalLM, VLChatProcessor
from functools import lru_cache


def prepare_classifier_free_guidance_input(input_embeds, vl_chat_processor, mmgpt, batch_size=16):
    # Unconditional branch: pad tokens everywhere except BOS at the front and EOS at the
    # end, matching the length of the conditional prompt.
    uncond_input_ids = torch.full((1, input_embeds.shape[1]),
                                  vl_chat_processor.pad_id,
                                  dtype=torch.long,
                                  device=input_embeds.device)
    uncond_input_ids[:, 0] = vl_chat_processor.tokenizer.bos_token_id
    uncond_input_ids[:, -1] = vl_chat_processor.tokenizer.eos_token_id

    uncond_input_embeds = mmgpt.language_model.get_input_embeddings()(uncond_input_ids)
    # Reuse the conditional prompt's final embedding (the image start tag) so both
    # branches end on the same generation trigger.
    uncond_input_embeds[:, -1, :] = input_embeds[:, -1, :]

    cond_input_embeds = input_embeds.repeat(batch_size, 1, 1)
    uncond_input_embeds = uncond_input_embeds.repeat(batch_size, 1, 1)

    # Interleave rows as [cond_0, uncond_0, cond_1, uncond_1, ...].
    combined_input_embeds = torch.stack([cond_input_embeds, uncond_input_embeds], dim=1)
    combined_input_embeds = combined_input_embeds.view(batch_size * 2, -1, input_embeds.shape[-1])

    return combined_input_embeds

@spaces.GPU
@torch.inference_mode()
def generate(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    inputs_embeds,
    temperature: float = 1,
    parallel_size: int = 1,
    cfg_weight: float = 5,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
):
    # Buffer for the discrete image tokens sampled for each parallel image.
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()

    # Interleave conditional / unconditional prompt embeddings for classifier-free guidance.
    inputs_embeds = prepare_classifier_free_guidance_input(inputs_embeds, vl_chat_processor, mmgpt, parallel_size)

    past_key_values = None
    for i in range(image_token_num_per_image):
        outputs = mmgpt.language_model.model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
            past_key_values=past_key_values,
        )
        past_key_values = outputs.past_key_values
        hidden_states = outputs.last_hidden_state

        logits = mmgpt.gen_head(hidden_states[:, -1, :])

        # Even rows are the conditional branch, odd rows the unconditional one.
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]

        # Classifier-free guidance: move logits away from the unconditional prediction.
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)

        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)

        # Duplicate each sampled token for its cond/uncond pair and feed its image
        # embedding back in as the next single-token input.
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)

    # Decode the sampled image tokens back to pixels with the VQ decoder.
    dec = mmgpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size],
    )
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)

    # Map the decoder output from [-1, 1] to [0, 255] and clamp.
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    generated_images = []
    for i in range(parallel_size):
        generated_images.append(PIL.Image.fromarray(visual_img[i]))

    return generated_images

@lru_cache(maxsize=1)
def get_start_tag_embed(vl_gpt, vl_chat_processor):
    # Embed the image start tag once and cache it; it is appended to the prompt
    # embeddings to signal the start of image-token generation.
    with torch.no_grad():
        return vl_gpt.language_model.get_input_embeddings()(
            vl_chat_processor.tokenizer.encode(vl_chat_processor.image_start_tag, add_special_tokens=False, return_tensors="pt").to(vl_gpt.device)
        )
    
def process_and_generate(vl_gpt, vl_chat_processor, input_image, prompt, num_images=4, cfg_weight=5):
    start_tag_embed = get_start_tag_embed(vl_gpt, vl_chat_processor)

    # Keep the newline in a variable: backslashes are not allowed inside f-string
    # expressions before Python 3.12.
    nl = '\n'
    conversation = [
        {
            "role": "User",
            "content": f"<image_placeholder>{nl + prompt if prompt else ''}",
            "images": [input_image],
        },
        {"role": "Assistant", "content": ""},
    ]

    pil_images = load_pil_images(conversation)
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(vl_gpt.device)

    with torch.no_grad():
        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    # Append the image start tag embedding so the model starts emitting image tokens.
    inputs_embeds = torch.cat((inputs_embeds, start_tag_embed), dim=1)

    generated_images = generate(
        vl_gpt,
        vl_chat_processor,
        inputs_embeds,
        parallel_size=num_images,
        cfg_weight=cfg_weight
    )

    return generated_images
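

# --- Minimal usage sketch ---------------------------------------------------
# Assumptions (not part of this file): the "deepseek-ai/Janus-1.3B" checkpoint,
# the standard Janus loading flow from the upstream README, and a local
# conditioning image at the hypothetical path "input.jpg". Adjust the model
# path, dtype, and device for your setup.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    model_path = "deepseek-ai/Janus-1.3B"
    vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
    vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

    images = process_and_generate(
        vl_gpt,
        vl_chat_processor,
        input_image="input.jpg",  # hypothetical path; load_pil_images opens it with PIL
        prompt="a watercolor painting of this scene",
        num_images=4,
        cfg_weight=5,
    )
    for idx, img in enumerate(images):
        img.save(f"generated_{idx}.png")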