import os
import math
import platform
import logging
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM
from PIL import Image
import numpy as np
from numpy import asarray
import gradio as gr
import clip

from arguments import get_args
from load import load_model_args, load_model
from utils import get_first_sentence


logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
log = logging.getLogger(__name__)


def prepare(args):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    log.info(f'Device: {device}')
    args = load_model_args(args)

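    # load_style builds the style generator: it starts from the base causal LM and,
    # if a checkpoint file is given, loads the 'model.'-prefixed weights from its
    # state_dict (presumably a PyTorch Lightning checkpoint); otherwise the vanilla
    # pretrained weights are kept.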
    def load_style(args, checkpoint):
        model = AutoModelForCausalLM.from_pretrained(args.init_model)
        if checkpoint is not None and Path(checkpoint).is_file():
            log.info("joint model: loading pretrained style generator")
            state = torch.load(checkpoint, map_location=torch.device('cpu'))
            if 'global_step' in state:
                step = state['global_step']
                log.info(f'trained for {step} steps')
            weights = state['state_dict']
            key = 'model.'
            weights = {k[len(key):]: v for k, v in weights.items() if k.startswith(key)}
            model.load_state_dict(weights)
        else:
            log.info("joint model: loading vanila gpt")
        return model

    log.info('loading models')
    joint_model = load_style(args, checkpoint=getattr(args, 'demo_joint_model_weight', None))
    joint_model = joint_model.to(device)
    model = load_model(args, device)
    tokenizer = model.tokenizer
    log.info('loaded models')

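    # Inferer bundles the caption model, the style generator, and the CLIP image
    # encoder behind a single callable used by the Gradio demo.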
    class Inferer:
        def __init__(self, args, model, joint_model, tokenizer, device):
            self.args = args
            self.model = model
            self.joint_model = joint_model
            self.tokenizer = tokenizer
            self.device = device

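            # CLIP image encoder used to extract the visual features that condition generation.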
            self.clip_model, self.clip_preprocess = clip.load(args.clip_model_type, device=device, jit=False)

        def infer_joint(self, batch, window_size=10, vanilla_length=20, sample=False, temperature=0.7, **kwargs):
            with torch.no_grad():
                rollouts = self.model.sample(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'],
                                             features=batch['features'], labels=None,
                                             max_len=self.args.response_length, sample=sample,
                                             no_repeat_ngram_size=self.args.infer_no_repeat_size,
                                             invalidate_eos=False)
                '''
                query = rollouts['query/input_ids']
                res = rollouts['response/input_ids']
                gen1 = torch.cat([query, res], dim=1)
                mask1 = torch.cat([rollouts['query/mask'], rollouts['response/mask']], dim=1)
                '''
                res = rollouts['response/text']
                query = rollouts['query/text']
                generations = [f'{q} {v.strip()}' for q, v in zip(query, res)]

                cur_length = self.args.response_length
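                # Optionally extend the caption beyond response_length with the style
                # generator: repeatedly feed it the last window_size tokens and append
                # its continuation until roughly vanilla_length extra tokens are produced.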
                if vanilla_length > 0:
                    for i in range(math.ceil(vanilla_length / window_size)):
                        cur_length += window_size
                        generations = self.tokenizer(generations, padding=True, return_tensors='pt').to(self.device)
                        context = generations['input_ids'][:, :-window_size]
                        inputs = generations['input_ids'][:, -window_size:]
                        out = self.joint_model.generate(input_ids=inputs,
                                                        max_length=cur_length, do_sample=sample,
                                                        no_repeat_ngram_size=self.args.infer_no_repeat_size,
                                                        pad_token_id=self.tokenizer.eos_token_id)
                        out = torch.cat([context, out], dim=1)
                        text = [self.tokenizer.decode(v, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                                for v in out]
                        # generations = [get_first_sentence(v) for v in generations]
                        generations = text
                query = rollouts['query/text']
                del rollouts
            torch.cuda.empty_cache()
            return query, generations

        def get_feature(self, image):
            image = self.clip_preprocess(image).unsqueeze(0).to(self.device)
            feature = self.clip_model.encode_image(image)
            return feature

        def __call__(self, image, prompt, length=20, window_size=20, **kwargs):
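            # The base model always generates response_length tokens; anything beyond
            # that (up to the requested length) is delegated to the style generator.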
            window_size = min(window_size, length)
            vanilla_length = max(0, length - self.args.response_length)
            if not prompt:
                prompt = 'The'
            feature = self.get_feature(image)
            feature = feature.unsqueeze(0).to(self.device)
            batch = self.tokenizer(prompt, padding=True, return_tensors='pt').to(self.device)
            batch['features'] = feature
            query, generations = self.infer_joint(batch, window_size=window_size,
                                                  vanilla_length=vanilla_length, **kwargs)
            # text = f'{query[0].strip()} {generations[0].strip()}'
            text = generations[0].strip()
            return text

    inferer = Inferer(args, model, joint_model, tokenizer, device)
    return inferer


class Runner:
    def __init__(self, inferer):
        self.inferer = inferer

    def __call__(self, inp, prompt, length, window_size, sample):
        # inp = inp.reshape((224, 224, 3))
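        # Gradio hands the image over as a numpy array; convert it to PIL for CLIP preprocessing.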
        img = Image.fromarray(np.uint8(inp))
        text = self.inferer(img, prompt, length, window_size, sample=sample)
        return prompt, text
        # return inp, prompt, text


'''
# test_run
sample_img = asarray(Image.open('../data/coco/images/sample.jpg'))
img, _, text = run(sample_img, 'There lies', 50, 20, sample=False)
print('test_run:', text)
'''

def launch(examples=None):
    args = get_args()
    inferer = prepare(args)
    runner = Runner(inferer)

    iface = gr.Interface(
        title="Demo for ESPER",
        fn=runner.__call__,
        inputs=[gr.components.Image(shape=(224, 224)),
                gr.components.Textbox(label='prompt'),
                gr.components.Slider(20, 120, step=1, label='length'),
                gr.components.Slider(10, 100, step=1, label='window_size'),
                gr.components.Checkbox(label='do sample')],
        outputs=[gr.components.Textbox(label='prompt'),
                gr.components.Textbox(label='generation')],
        examples=examples
    )
    if args.port is not None:
        print(f"running from {platform.node()}")
        iface.launch(
            server_name="0.0.0.0",
            server_port=args.port
        )
    else:
        iface.launch()


if __name__ == "__main__":
    print(f"running from {platform.node()}")
    launch()