#!/usr/bin/env python
"""Gradio web demo for AnimeGANv2 (https://github.com/TachibanaYoshino/AnimeGANv2).

Loads the Shinkai-style generator from the bundled ``animeganv2`` directory
and serves a UI that turns an uploaded image into three result panels.
"""

from __future__ import annotations

import argparse
import functools
import os
import pathlib
import sys
from io import BytesIO
from typing import Callable

import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image

# The model code lives in the ``animeganv2`` directory next to this script.
# Resolve it from ``__file__`` (as the checkpoint paths in ``main`` already
# are) so the import does not depend on the current working directory.
sys.path.insert(
    0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'animeganv2'))
import test1 as test  # noqa: E402  (needs the sys.path entry above)
from test1 import ImportGraph  # noqa: E402

ORIGINAL_REPO_URL = 'https://github.com/TachibanaYoshino/AnimeGANv2'
TITLE = 'TachibanaYoshino/AnimeGANv2'
DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.

"""
ARTICLE = """

"""


def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options.

    Returns:
        Namespace with ``device`` (NOTE: not read anywhere in this script),
        ``theme``, ``live``, ``share``, ``port``, ``enable_queue``
        (cleared by ``--disable-queue``), ``allow_flagging`` and
        ``allow_screenshot``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--theme', type=str)
    parser.add_argument('--live', action='store_true')
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    parser.add_argument('--allow-screenshot', action='store_true')
    return parser.parse_args()


def run(
    image,
    shinkai: ImportGraph,
    hayao: ImportGraph | None = None,
    paprika: ImportGraph | None = None,
) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
    """Stylize ``image`` with each available AnimeGANv2 generator.

    Args:
        image: The uploaded file object from Gradio; only its ``.name``
            (the on-disk path) is used.
        shinkai: Loaded Shinkai-style generator (required).
        hayao: Optional Hayao-style generator. If ``None``, the Hayao panel
            shows the Shinkai result instead.
        paprika: Optional Paprika-style generator. If ``None``, the Paprika
            panel shows the Shinkai result instead.

    Returns:
        Three images for the Shinkai, Hayao and Paprika output panels.
    """
    # ImportGraph.test's return value is passed to PIL.Image.open, so it is
    # presumably the path of the written output image (TODO confirm in test1).
    # The meaning of the trailing ``True`` flag is defined in test1.
    shinkai_out = shinkai.test('shinkai', image.name, True)
    # Models that are not loaded fall back to the Shinkai output so all three
    # output panels are always filled.
    hayao_out = (hayao.test('hayao', image.name, True)
                 if hayao is not None else shinkai_out)
    paprika_out = (paprika.test('paprika', image.name, True)
                   if paprika is not None else shinkai_out)
    return (PIL.Image.open(shinkai_out),
            PIL.Image.open(hayao_out),
            PIL.Image.open(paprika_out))


def main() -> None:
    """Load the generator(s), build the Gradio interface and launch it."""
    gr.close_all()

    args = parse_args()

    cur_path = os.path.abspath(os.path.dirname(__file__))
    checkpoint_root = os.path.join(cur_path, 'animeganv2', 'checkpoint')

    shinkai = ImportGraph(checkpoint_dir=os.path.join(
        checkpoint_root, 'generator_Shinkai_weight'))
    # NOTE: the Hayao and Paprika generators are currently disabled. To enable
    # them, load them from 'generator_Hayao_weight' / 'generator_Paprika_weight'
    # under ``checkpoint_root`` and pass them to the partial below as
    # ``hayao=`` / ``paprika=``.

    # ``run`` defaults ``hayao``/``paprika`` to None, so binding only
    # ``shinkai`` yields a callable that Gradio can invoke with just the image.
    func = functools.partial(run, shinkai=shinkai)
    # Copy __name__/__doc__ from ``run`` onto the partial.
    func = functools.update_wrapper(func, run)

    gr.Interface(
        func,
        [
            gr.inputs.Image(type='file', label='Input Image'),
        ],
        [
            gr.outputs.Image(type='pil', label='Shinkai Result'),
            gr.outputs.Image(type='pil', label='Hayao Result'),
            gr.outputs.Image(type='pil', label='Paprika Result'),
        ],
        theme=args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()