import os
import math
import copy
import platform
import logging
from pathlib import Path
from itertools import chain

import torch
from transformers import AutoModelForCausalLM
from PIL import Image
import numpy as np
from numpy import asarray
import gradio as gr
import clip

from arguments import get_args
from load import load_model_args, load_model
from utils import get_first_sentence

logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
log = logging.getLogger(__name__)


def prepare(args):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    log.info(f'Device: {device}')
    args = load_model_args(args)

    def load_style(args, checkpoint):
        # Load the style (joint) generator: a finetuned checkpoint if given, otherwise the base LM.
        model = AutoModelForCausalLM.from_pretrained(args.init_model)
        if checkpoint is not None and Path(checkpoint).is_file():
            log.info("joint model: loading pretrained style generator")
            state = torch.load(checkpoint, map_location=torch.device('cpu'))
            if 'global_step' in state:
                step = state['global_step']
                log.info(f'trained for {step} steps')
            weights = state['state_dict']
            # Strip the 'model.' prefix left by the training wrapper.
            key = 'model.'
            weights = {k[len(key):]: v for k, v in weights.items() if k.startswith(key)}
            model.load_state_dict(weights)
        else:
            log.info("joint model: loading vanilla GPT")
        return model

    log.info('loading models')
    joint_model = load_style(args, checkpoint=getattr(args, 'demo_joint_model_weight', None))
    joint_model = joint_model.to(device)
    model = load_model(args, device)
    tokenizer = model.tokenizer
    log.info('loaded models')

    class Inferer:
        def __init__(self, args, model, joint_model, tokenizer, device):
            self.args = args
            self.model = model
            self.joint_model = joint_model
            self.tokenizer = tokenizer
            self.device = device
            self.clip_model, self.clip_preprocess = clip.load(args.clip_model_type, device=device, jit=False)

        def infer_joint(self, batch, window_size=10, vanilla_length=20, sample=False, temperature=0.7, **kwargs):
            with torch.no_grad():
                # Stage 1: sample the image-grounded prefix with the CLIP-conditioned model.
                rollouts = self.model.sample(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'],
                                             features=batch['features'], labels=None,
                                             max_len=self.args.response_length, sample=sample,
                                             no_repeat_ngram_size=self.args.infer_no_repeat_size,
                                             invalidate_eos=False)
                '''
                query = rollouts['query/input_ids']
                res = rollouts['response/input_ids']
                gen1 = torch.cat([query, res], dim=1)
                mask1 = torch.cat([rollouts['query/mask'], rollouts['response/mask']], dim=1)
                '''
                res = rollouts['response/text']
                query = rollouts['query/text']
                generations = [f'{q} {v.strip()}' for q, v in zip(query, res)]

                cur_length = self.args.response_length
                if vanilla_length > 0:
                    # Stage 2: let the style model continue the text, one window at a time.
                    for i in range(math.ceil(vanilla_length / window_size)):
                        cur_length += window_size
                        generations = self.tokenizer(generations, padding=True, return_tensors='pt').to(self.device)
                        context = generations['input_ids'][:, :-window_size]
                        inputs = generations['input_ids'][:, -window_size:]
                        out = self.joint_model.generate(input_ids=inputs,
                                                        max_length=cur_length, do_sample=sample,
                                                        no_repeat_ngram_size=self.args.infer_no_repeat_size,
                                                        pad_token_id=self.tokenizer.eos_token_id)
                        out = torch.cat([context, out], dim=1)
                        text = [self.tokenizer.decode(v, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                                for v in out]
                        # generations = [get_first_sentence(v) for v in generations]
                        generations = text

            query = rollouts['query/text']
            del rollouts
            torch.cuda.empty_cache()
            return query, generations

        def get_feature(self, image):
            image = self.clip_preprocess(image).unsqueeze(0).to(self.device)
            feature = self.clip_model.encode_image(image)
            return feature

        def __call__(self, image, prompt, length=20, window_size=20, **kwargs):
            window_size = min(window_size, length)
            # Tokens beyond the base response length are produced by the style model.
            vanilla_length = max(0, length - self.args.response_length)
            if not prompt:
                prompt = 'The'
            feature = self.get_feature(image)
            feature = feature.unsqueeze(0).to(self.device)
            batch = self.tokenizer(prompt, padding=True, return_tensors='pt').to(self.device)
            batch['features'] = feature
            query, generations = self.infer_joint(batch, window_size=window_size,
                                                  vanilla_length=vanilla_length, **kwargs)
            # text = f'{query[0].strip()} {generations[0].strip()}'
            text = generations[0].strip()
            return text

    inferer = Inferer(args, model, joint_model, tokenizer, device)
    return inferer
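
# A minimal sketch of using `prepare` outside of Gradio. The image path and prompt are
# assumptions; real arguments come from `arguments.get_args`.
#   args = get_args()
#   inferer = prepare(args)
#   image = Image.open('sample.jpg')
#   print(inferer(image, 'The', length=30, sample=False))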


class Runner:
    def __init__(self, inferers):
        self.inferers = inferers

    def __call__(self, model_name, inp, prompt, length, sample):
        inferer = self.inferers[model_name]
        # Gradio passes the image as a numpy array; convert it back to PIL for CLIP preprocessing.
        # inp = inp.reshape((224, 224, 3))
        img = Image.fromarray(np.uint8(inp))
        text = inferer(img, prompt, length, window_size=10, sample=sample)
        return prompt, text
        # return inp, prompt, text

'''
# test_run
sample_img = asarray(Image.open('../data/coco/images/sample.jpg'))
img, _, text = run(sample_img, 'There lies', 50, 20, sample=False)
print('test_run:', text)
'''
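
# A sketch of calling the Runner directly; the model name and image path are assumptions
# (model names are derived from the *.ckpt files found under args.checkpoint).
#   runner = Runner({'base': prepare(get_args())})
#   img = asarray(Image.open('sample.jpg'))
#   prompt, text = runner('base', img, 'There lies', 30, False)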


def launch(examples=None, title='Demo for ESPER', description=None, prompt_eg=None):
    args = get_args()
    # Build one inferer per checkpoint found in the checkpoint directory.
    ckpts = [p.parent / p.stem for p in Path(args.checkpoint).glob('*.ckpt')]
    ckpts = {p.stem: p for p in ckpts}
    inferers = {}
    for model_name, ckpt in ckpts.items():
        ckpt_args = copy.deepcopy(args)
        ckpt_args.checkpoint = str(ckpt)
        inferer = prepare(ckpt_args)
        inferers[model_name] = inferer
    runner = Runner(inferers)

    model_names = sorted(list(ckpts.keys()))
    log.info(f'model_names: {model_names}')
    if examples:
        # Prepend each model name so every example row matches the input signature.
        examples = list(chain(*[[[n, *ex] for n in model_names] for ex in examples]))
    iface = gr.Interface(
        title=title,
        description=description,
        fn=runner.__call__,
        inputs=[gr.components.Dropdown(choices=model_names, value=model_names[0], label='Backbone'),
                gr.components.Image(shape=(224, 224), label='Image'),
                gr.components.Textbox(label='Prompt', placeholder=prompt_eg),
                gr.components.Slider(20, 40, step=1, label='Length'),
                # gr.components.Slider(10, 100, step=1, label='window_size'),
                gr.components.Checkbox(label='do sample')],
        outputs=[gr.components.Textbox(label='Prompt'),
                 gr.components.Textbox(label='Generation')],
        examples=examples
    )
    if args.port is not None:
        print(f"running from {platform.node()}")
        iface.launch(
            server_name="0.0.0.0",
            server_port=args.port
        )
    else:
        iface.launch()


if __name__ == "__main__":
    print(f"running from {platform.node()}")
    launch()
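
# Example invocation (a sketch; the script name and exact flag names depend on
# arguments.get_args, which is not shown here):
#   python demo.py --checkpoint <dir containing *.ckpt files> --port 7860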