File size: 1,779 Bytes
2c83deb
d70b1c6
2c83deb
 
b4614a6
2c83deb
 
26fde58
 
 
 
 
 
2c83deb
 
 
 
 
2c745bf
 
2c83deb
 
 
2c745bf
2c83deb
d4c5cc1
b4614a6
d4c5cc1
2c83deb
 
be8f6aa
2c83deb
 
4fa2b44
2c83deb
 
 
be8f6aa
 
 
2c83deb
01e3be6
 
 
 
 
b4614a6
01e3be6
e4c9997
b4614a6
22d7300
ba24c0a
 
 
22d7300
2c83deb
d188461
ba24c0a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import  Dict, List, Any
from PIL import Image
import torch
from torch import autocast
from diffusers import StableDiffusionUpscalePipeline
import base64
from io import BytesIO
from transformers.utils import logging

# Raise the transformers library logger to INFO so the handler's
# ``logger.info`` calls below are actually emitted.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Prefer CUDA when it is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class EndpointHandler():
    """Inference Endpoint handler that upscales images with a diffusion upscaler.

    The ``StableDiffusionUpscalePipeline`` is loaded once in ``__init__`` and
    reused for every request handled by ``__call__``.
    """

    def __init__(self, path=""):
        """Load the upscaler pipeline from ``path`` and move it to ``device``.

        Args:
            path: Model location passed to
                ``StableDiffusionUpscalePipeline.from_pretrained``
                (e.g. a local copy of ``stabilityai/stable-diffusion-x4-upscaler``).
        """
        self.path = path
        # Half precision is only reliable on CUDA; many CPU kernels have no
        # float16 implementation, so use float32 when running on CPU.
        dtype = torch.float16 if device.type == "cuda" else torch.float32
        self.pipe = StableDiffusionUpscalePipeline.from_pretrained(path, torch_dtype=dtype)
        self.pipe = self.pipe.to(device)

    def __call__(self, data) -> Dict[str, Any]:
        """Upscale a base64-encoded image, guided by a text prompt.

        Args:
            data: Request payload. ``data["inputs"]`` must be a mapping with
                ``"image"`` (base64-encoded image file bytes) and ``"prompt"``
                (text passed to the pipeline).

        Returns:
            ``{"image": str}`` where the value is the upscaled image encoded
            as JPEG and then base64.

        Raises:
            TypeError: if ``data`` has no ``"inputs"`` entry.
            KeyError: if ``"image"`` or ``"prompt"`` is missing from the inputs.
        """
        inputs = data.get("inputs")
        # Log only the keys: the payload carries the whole base64 image, which
        # would flood the logs at INFO level.
        logger.info('inputs received with keys %s', list(inputs))

        image = self._decode_image(inputs['image'])
        prompt = inputs['prompt']
        logger.info('upscaling %dx%d image', image.width, image.height)

        # On CUDA the pipeline is already float16; on CPU autocast would
        # default to bfloat16, so only enable it on the GPU.
        with autocast(device.type, enabled=device.type == "cuda"):
            upscaled_image = self.pipe(prompt, image).images[0]

        # postprocess the prediction
        return {"image": self._encode_image(upscaled_image)}

    @staticmethod
    def _decode_image(image_b64):
        """Decode base64-encoded image file bytes into an RGB ``PIL.Image``."""
        image = Image.open(BytesIO(base64.b64decode(image_b64)))
        # Normalise RGBA / palette / grayscale uploads to the 3-channel layout
        # the pipeline consumes (this also forces the lazy decode).
        return image.convert("RGB")

    @staticmethod
    def _encode_image(image):
        """Serialize ``image`` as JPEG and return it base64-encoded as ``str``."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("ascii")