TriplaneGaussian / utils.py
zouzx's picture
update gradio theme
c839178
import os
import time
import cv2
import numpy as np
import torch
from PIL import Image
from rembg import remove
from segment_anything import SamPredictor, sam_model_registry
import urllib.request
from tqdm import tqdm
def sam_init(sam_checkpoint, device_id=0):
    """Create a Segment Anything (ViT-H) predictor.

    Args:
        sam_checkpoint: checkpoint path handed to the ``vit_h`` model builder.
        device_id: CUDA device index; ignored when CUDA is unavailable.

    Returns:
        A ``SamPredictor`` wrapping the model, placed on ``cuda:<device_id>``
        when CUDA is available and on the CPU otherwise.
    """
    if torch.cuda.is_available():
        target_device = "cuda:{}".format(device_id)
    else:
        target_device = "cpu"

    build_model = sam_model_registry["vit_h"]
    sam_model = build_model(checkpoint=sam_checkpoint)
    sam_model = sam_model.to(device=target_device)
    return SamPredictor(sam_model)
def sam_out_nosave(predictor, input_image, *bbox_sliders):
    """Segment the object inside a box prompt with SAM and return it as RGBA.

    Args:
        predictor: a SAM ``SamPredictor`` (must provide ``set_image`` and
            ``predict``).
        input_image: image accepted by ``np.asarray`` (presumably a uint8
            ``PIL.Image`` from the UI -- TODO confirm). A 4th (alpha)
            channel is dropped and a single-channel image is replicated to
            3 channels, because SAM's ``set_image`` expects HxWx3 uint8 and
            the output below reserves exactly 3 colour channels.
        *bbox_sliders: the box prompt, forwarded as ``box=`` to
            ``predictor.predict``; SAM documents this as XYXY order, i.e.
            (x_min, y_min, x_max, y_max).

    Returns:
        ``PIL.Image`` in RGBA mode: the RGB input with alpha 255 inside the
        selected mask and 0 elsewhere.
    """
    bbox = np.array(bbox_sliders)
    image = np.asarray(input_image)

    # Normalise to HxWx3: RGBA input previously crashed on the channel copy below.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = image[:, :, :3]
    image = np.ascontiguousarray(image)

    predictor.set_image(image)
    masks, _scores, _logits = predictor.predict(box=bbox, multimask_output=True)

    height, width = image.shape[:2]
    out_image = np.zeros((height, width, 4), dtype=np.uint8)
    out_image[:, :, :3] = image
    # Keep the last of the multimask candidates; the original author noted
    # np.argmax(scores) as an alternative selection rule.
    out_image[:, :, 3] = masks[-1].astype(np.uint8) * 255

    # Release cached GPU memory held after inference (no-op without CUDA).
    torch.cuda.empty_cache()
    return Image.fromarray(out_image, mode="RGBA")
# contrast correction, rescale and recenter
def image_preprocess(input_image, save_path, lower_contrast=True, rescale=True):
    """Lower contrast, crop the foreground, recenter it on a square canvas, save 256x256 RGBA.

    Args:
        input_image: ``PIL.Image``; it is converted to RGBA and its alpha
            channel (any non-zero value) defines the foreground.
        save_path: destination passed to ``PIL.Image.save``.
        lower_contrast: if True, every channel (alpha included) is scaled
            by 0.8 with ``cv2.convertScaleAbs``, then alpha values still
            above 200 are snapped back to 255.
        rescale: if True, the canvas side is the longer side of the
            foreground bounding box divided by 0.75, so the object spans
            ~75% of the canvas; if False, the canvas side is the longer
            side of the input image, preserving the original object scale.

    Raises:
        ValueError: if the alpha channel is fully transparent.
    """
    # The canvas below is RGBA and the mask is read from alpha; a 3-channel
    # input previously crashed on the crop assignment.
    input_image = input_image.convert("RGBA")
    image_arr = np.array(input_image)
    # numpy image arrays are (rows, cols) = (height, width).
    in_h, in_w = image_arr.shape[:2]

    if lower_contrast:
        alpha = 0.8  # contrast gain
        beta = 0  # brightness offset
        image_arr = cv2.convertScaleAbs(image_arr, alpha=alpha, beta=beta)
        # After the 0.8 gain, alpha > 200 means the pixel was nearly opaque; make it opaque again.
        image_arr[image_arr[..., -1] > 200, -1] = 255

    # Foreground = pixels with non-zero alpha in the unmodified input.
    _, mask = cv2.threshold(
        np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY
    )
    x, y, w, h = cv2.boundingRect(mask)
    if w == 0 or h == 0:
        raise ValueError(
            "image_preprocess: alpha channel is fully transparent, nothing to crop"
        )

    max_size = max(w, h)
    ratio = 0.75
    if rescale:
        side_len = int(max_size / ratio)
    else:
        # Must cover the crop in both directions; using a single dimension
        # broke on landscape inputs with wide objects.
        side_len = max(in_h, in_w)

    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    center = side_len // 2
    top = center - h // 2
    left = center - w // 2
    padded_image[top : top + h, left : left + w] = image_arr[y : y + h, x : x + w]

    rgba = Image.fromarray(padded_image).resize((256, 256), Image.LANCZOS)
    rgba.save(save_path)
def pred_bbox(image):
    """Estimate the foreground bounding box of ``image`` via rembg background removal.

    Args:
        image: ``PIL.Image``; it is converted to RGBA before ``remove`` is called
            with ``alpha_matting=True``.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as Python ints: inclusive pixel
        coordinates of the columns/rows that contain any non-zero alpha in
        the background-removed result.

    Raises:
        ValueError: if background removal leaves no non-transparent pixel.
    """
    image_nobg = remove(image.convert("RGBA"), alpha_matting=True)
    alpha = np.asarray(image_nobg)[:, :, -1]

    # Indices of columns / rows containing any foreground, in ascending order.
    cols = np.flatnonzero(alpha.any(axis=0))
    rows = np.flatnonzero(alpha.any(axis=1))
    if cols.size == 0:
        # Previously surfaced as numpy's opaque "zero-size array" ValueError.
        raise ValueError("pred_bbox: background removal produced an empty foreground")

    return int(cols[0]), int(rows[0]), int(cols[-1]), int(rows[-1])
# convert a function into recursive style to handle nested dict/list/tuple variables
def make_recursive_func(func):
    """Decorate ``func`` so it is applied to every leaf of nested lists, tuples and dicts.

    Containers are rebuilt with the same kind (list -> list, tuple -> tuple,
    dict -> dict with the same keys); any other value is a leaf and is passed
    to ``func`` along with the extra ``*args`` / ``**kwargs``.

    The wrapper keeps ``func``'s ``__name__``/``__doc__`` via ``functools.wraps``.
    """
    import functools

    @functools.wraps(func)
    def wrapper(vars, *args, **kwargs):
        if isinstance(vars, list):
            return [wrapper(x, *args, **kwargs) for x in vars]
        if isinstance(vars, tuple):
            return tuple(wrapper(x, *args, **kwargs) for x in vars)
        if isinstance(vars, dict):
            return {k: wrapper(v, *args, **kwargs) for k, v in vars.items()}
        return func(vars, *args, **kwargs)

    return wrapper
@make_recursive_func
def todevice(vars, device="cuda"):
    """Move every tensor leaf to ``device``; return str/bool/int/float leaves unchanged.

    The ``make_recursive_func`` decorator applies this to each leaf of nested
    lists/tuples/dicts.

    Raises:
        NotImplementedError: for any other leaf type.
    """
    if isinstance(vars, torch.Tensor):
        return vars.to(device)
    # bool is a subclass of int; listed explicitly to document intent.
    if isinstance(vars, (str, bool, int, float)):
        return vars
    # Message previously named "tensor2numpy" (copy-paste from another helper).
    raise NotImplementedError("invalid input type {} for todevice".format(type(vars)))