# Source: Hugging Face repository "vierundvi", file handler.py (commit b93bdbf by mart9992).
import os
import subprocess
import torch
import requests
from PIL import Image
from io import BytesIO
from test import just_get_sd_mask
def _log_environment():
    """Print container diagnostics to stdout.

    Emits two lines:
    1. the entries found directly under ``/usr/local/``;
    2. ``torch.version.cuda``, the CUDA version reported by the installed torch build.
    """
    usr_local_entries = os.listdir('/usr/local/')
    print(usr_local_entries)
    cuda_build = torch.version.cuda
    print(cuda_build)


# Runs once at import time, exactly as the original top-level prints did.
_log_environment()
class EndpointHandler():
    """Inference-endpoint handler that returns a download link to a segmentation mask.

    Every call runs ``just_get_sd_mask`` on the bundled demo image
    (``assets/demo1.jpg``) with the prompt ``"bear"``. It encodes the resulting
    mask as JPEG, uploads it to https://file.io/ and returns the link.

    NOTE(review): file.io is a public third-party host. Every mask produced
    here leaves the deployment; confirm that is acceptable for real inputs.
    """

    # Demo image, resolved relative to the repository directory given to __init__.
    IMAGE_RELPATH = os.path.join("assets", "demo1.jpg")
    # Seconds requests may wait on connect/read to file.io. requests has no default
    # timeout, so without this a stalled upload would block the handler forever.
    UPLOAD_TIMEOUT = 60

    def __init__(self, path="."):
        """Remember where the repository lives so bundled assets do not depend on the CWD.

        Args:
            path: Directory containing this handler and its ``assets/`` folder.
        """
        self.path = path

    def __call__(self, data):
        """Segment the demo image, upload the mask to file.io and return its link.

        Args:
            data: Request payload. Currently ignored: every call processes the
                same bundled image with the same prompt.

        Returns:
            ``{"url": link}``, where ``link`` is the ``"link"`` field of file.io's
            JSON reply (``None`` if the reply has no such field).

        Raises:
            requests.HTTPError: if file.io answers with a 4xx/5xx status.
            requests.Timeout: if connecting to or reading from file.io stalls
                longer than ``UPLOAD_TIMEOUT`` seconds.
        """
        # getattr tolerates instances that bypassed __init__; "." reproduces the old CWD lookup.
        image_path = os.path.join(getattr(self, "path", "."), self.IMAGE_RELPATH)

        # The `with` closes the source file handle (previously leaked on every call).
        # All image work stays inside it in case the mask shares data with the input image.
        with Image.open(image_path) as image:
            # NOTE(review): the meaning of the third argument (10) is defined in
            # test.just_get_sd_mask, which is not visible here.
            mask_pil = just_get_sd_mask(image, "bear", 10)
            # Normalise to RGB so the JPEG encoder accepts the mask
            # (modes such as RGBA cannot be written as JPEG).
            if mask_pil.mode != "RGB":
                mask_pil = mask_pil.convert("RGB")
            # NOTE(review): JPEG is lossy and blurs hard mask edges; PNG would be
            # exact. JPEG is kept for compatibility with existing consumers.
            buffer = BytesIO()
            mask_pil.save(buffer, format="JPEG")
        img_bytes = buffer.getvalue()

        # Give the part an explicit filename and content type. Without one,
        # requests names the upload after the form key ("file") with no extension.
        response = requests.post(
            "https://file.io/",
            files={"file": ("mask.jpg", img_bytes, "image/jpeg")},
            timeout=self.UPLOAD_TIMEOUT,
        )
        # Surface upload failures instead of silently returning {"url": None}.
        response.raise_for_status()
        url = response.json().get("link")
        return {"url": url}