SadTalker / src /utils /face_enhancer.py
nijisakai's picture
Duplicate from vinthony/SadTalker
585c7ea
raw
history blame
4.44 kB
import os
import torch
from gfpgan import GFPGANer
from tqdm import tqdm
from src.utils.videoio import load_video_to_cv2
import cv2
class GeneratorWithLen:
    """Pair an iterator with an explicitly supplied length so ``len()`` works.

    Adapted from https://stackoverflow.com/a/7460929
    """

    def __init__(self, gen, length):
        # `length` is trusted as given; it is never checked against `gen`.
        self.gen, self.length = gen, length

    def __len__(self):
        return self.length

    def __iter__(self):
        # Returns the wrapped object as-is; nothing is copied or re-created.
        return self.gen
def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Run the face enhancer over ``images`` and return every result in a list.

    This eagerly drains ``enhancer_generator_no_len``, so all enhanced frames
    are held in memory at once. ``method`` and ``bg_upsampler`` are forwarded
    unchanged.
    """
    enhanced_frames = [
        *enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    ]
    return enhanced_frames
def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Return a lazy, sized iterable of enhanced frames.

    Wraps the generator produced by ``enhancer_generator_no_len`` in a
    ``GeneratorWithLen`` so that it can be passed to functions that call
    ``len()``.

    Args:
        images: Either a list of frames or a path to a video file. A path is
            first decoded into frames with ``load_video_to_cv2``.
        method: Forwarded unchanged to ``enhancer_generator_no_len``.
        bg_upsampler: Forwarded unchanged to ``enhancer_generator_no_len``.

    Returns:
        A ``GeneratorWithLen`` whose length is ``len(images)`` (taken after
        any video decoding).
    """
    # Only probe the filesystem for non-list input: os.path.isfile() raises
    # TypeError when it is handed a list of frames. This mirrors the guard
    # used by enhancer_generator_no_len.
    if not isinstance(images, list) and os.path.isfile(images):  # handle video to images
        # TODO: Create a generator version of load_video_to_cv2
        images = load_video_to_cv2(images)
    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    return GeneratorWithLen(gen, len(images))
def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """Yield enhanced frames one at a time.

    Because this is a generator, none of the enhanced frames have to be kept
    in memory together, which can save a lot of RAM compared to collecting
    them in a list. Nothing runs (including argument validation and model
    construction) until the first item is requested.

    Args:
        images: A list of frames, or a path to a video file that is decoded
            with ``load_video_to_cv2``. Each frame is converted RGB -> BGR
            before restoration and BGR -> RGB afterwards.
        method: One of ``'gfpgan'``, ``'RestoreFormer'`` or ``'codeformer'``.
        bg_upsampler: ``'realesrgan'`` to request a RealESRGAN background
            upsampler (only built when CUDA is available); any other value
            disables background upsampling.

    Yields:
        The restored image returned by ``GFPGANer.enhance`` for each frame,
        converted back with ``cv2.COLOR_BGR2RGB``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    print('face enhancer....')
    if not isinstance(images, list) and os.path.isfile(images):
        # A single existing file is treated as a video and split into frames.
        images = load_video_to_cv2(images)

    # ------------------------ set up GFPGAN restorer ------------------------
    # method -> (arch, model_name, download url)
    restorer_specs = {
        'gfpgan': (
            'clean',
            'GFPGANv1.4',
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
        ),
        'RestoreFormer': (
            'RestoreFormer',
            'RestoreFormer',
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
        ),
        # TODO: NOTE(review): confirm that GFPGANer accepts arch='CodeFormer'.
        'codeformer': (
            'CodeFormer',
            'CodeFormer',
            'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
        ),
    }
    spec = restorer_specs.get(method)
    if spec is None:
        raise ValueError(f'Wrong model version {method}.')
    arch, model_name, url = spec
    channel_multiplier = 2  # identical for every supported method

    # ------------------------ set up background upsampler ------------------------
    upsampler = None
    if bg_upsampler == 'realesrgan':
        if torch.cuda.is_available():
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer
            rrdb_net = RRDBNet(
                num_in_ch=3, num_out_ch=3, num_feat=64,
                num_block=23, num_grow_ch=32, scale=2)
            upsampler = RealESRGANer(
                scale=2,
                model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
                model=rrdb_net,
                tile=400,
                tile_pad=10,
                pre_pad=0,
                half=True)  # need to set False in CPU mode
        else:  # CPU: skip the background upsampler entirely
            import warnings
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
                          'If you really want to use it, please modify the corresponding codes.')

    # determine model paths: first existing local copy wins, otherwise the
    # url is handed to GFPGANer as the model path
    weights_file = model_name + '.pth'
    local_candidates = (
        os.path.join('gfpgan/weights', weights_file),
        os.path.join('checkpoints', weights_file),
    )
    model_path = next((path for path in local_candidates if os.path.isfile(path)), url)

    restorer = GFPGANer(
        model_path=model_path,
        upscale=2,
        arch=arch,
        channel_multiplier=channel_multiplier,
        bg_upsampler=upsampler)

    # ------------------------ restore ------------------------
    for frame in tqdm(images, 'Face Enhancer:'):
        bgr_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        # restore faces and background if necessary; only the pasted-back
        # full image (third return value) is used
        _, _, restored = restorer.enhance(
            bgr_frame,
            has_aligned=False,
            only_center_face=False,
            paste_back=True)
        yield cv2.cvtColor(restored, cv2.COLOR_BGR2RGB)