# artline / app.py — Hugging Face Space by hylee
# Commit 6a6cd09 ("init"), 3.03 kB
#!/usr/bin/env python
from __future__ import annotations
import argparse
import functools
import os
import pathlib
import sys
from typing import Callable
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
from io import BytesIO
from fastai.vision import *
from fastai.utils.mem import *
from fastai.vision import load_learner
from core import FeatureLoss
import torchvision.transforms as T
# Upstream project this Space demonstrates; interpolated into DESCRIPTION.
ORIGINAL_REPO_URL = 'https://github.com/vijishmadhavan/ArtLine'
# Passed as ``title`` to the Gradio interface.
TITLE = 'vijishmadhavan/ArtLine'
# Passed as ``description`` to the Gradio interface (note the trailing newline).
DESCRIPTION = 'This is a demo for ' + ORIGINAL_REPO_URL + '.\n'
# Passed as ``article`` to the Gradio interface; currently just a newline.
ARTICLE = '\n'
# Hugging Face Hub repository the exported learner pickle is downloaded from.
MODEL_REPO = 'hylee/artline_model'
def parse_args() -> argparse.Namespace:
    """Build the demo's command-line parser and parse ``sys.argv``.

    Returns:
        The parsed options: ``device`` (default ``'cpu'``), ``theme``,
        ``live``, ``share``, ``port``, ``enable_queue`` (``True`` unless
        ``--disable-queue`` is given), ``allow_flagging`` (default
        ``'never'``) and ``allow_screenshot``.
    """
    # (flag, keyword arguments for add_argument), registered in this order.
    option_specs = (
        ('--device', {'type': str, 'default': 'cpu'}),
        ('--theme', {'type': str}),
        ('--live', {'action': 'store_true'}),
        ('--share', {'action': 'store_true'}),
        ('--port', {'type': int}),
        # Queueing defaults to on; this flag stores False into enable_queue.
        ('--disable-queue', {'dest': 'enable_queue', 'action': 'store_false'}),
        ('--allow-flagging', {'type': str, 'default': 'never'}),
        ('--allow-screenshot', {'action': 'store_true'}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def load_model(cache_dir: str = 'model',
               filename: str = 'ArtLine_650.pkl') -> str:
    """Download the exported ArtLine learner from the Hugging Face Hub.

    Args:
        cache_dir: Directory passed to ``hf_hub_download`` as ``cache_dir``.
            Defaults to ``'model'``.
        filename: File inside ``MODEL_REPO`` to fetch; also passed as
            ``force_filename`` so the file keeps this name on disk.
            Defaults to ``'ArtLine_650.pkl'``.

    Returns:
        The local path returned by ``huggingface_hub.hf_hub_download``.
    """
    # NOTE(review): ``force_filename`` is deprecated/removed in newer
    # huggingface_hub releases; callers should rely on the returned path
    # rather than assuming ``<cache_dir>/<filename>``. Confirm the pinned
    # hub version still accepts it.
    return huggingface_hub.hf_hub_download(MODEL_REPO,
                                           filename,
                                           cache_dir=cache_dir,
                                           force_filename=filename)
def run(
    image,
    learn,
) -> PIL.Image.Image:
    """Turn one uploaded picture into an ArtLine line-art image.

    Args:
        image: The upload as handed over by the Gradio ``'file'`` input
            (an object exposing ``.name``), or a plain filesystem path.
        learn: Loaded learner whose ``predict`` returns a 3-tuple; the
            second element is treated as the output image.

    Returns:
        The prediction as an 8-bit PIL image.
    """
    # Gradio's file input gives a temp-file wrapper; also accept a bare path.
    path = getattr(image, 'name', image)
    # Close the file handle once decoded, and force 3-channel RGB: ToTensor
    # would otherwise produce a 4-channel tensor for RGBA inputs or a
    # 1-channel one for grayscale (the model presumably expects 3 channels).
    with PIL.Image.open(path) as src:
        img = src.convert('RGB')
    img_t = T.ToTensor()(img)
    # ``Image`` comes from the ``fastai.vision`` star import, not PIL.
    img_fast = Image(img_t)
    _, img_hr, _ = learn.predict(img_fast)
    # Clip to [0, 1] before scaling so values stay within uint8 range.
    r = np.uint8(np.clip(image2np(img_hr), 0, 1) * 255)
    return PIL.Image.fromarray(r)
# Module-level handle to the loaded learner; ``main`` assigns it at start-up.
learn = None


def main():
    """Download and load the ArtLine model, then launch the Gradio demo.

    Side effects:
        Closes any running Gradio interfaces, sets the module-level ``learn``
        to the loaded learner, and builds and launches the Gradio interface
        configured from the command-line options.
    """
    global learn
    gr.close_all()
    args = parse_args()
    # NOTE(review): ``args.device`` is parsed but never used; the learner is
    # loaded with fastai's default device handling. Confirm whether GPU
    # support was intended.
    model_path = pathlib.Path(load_model())
    # Load from wherever the Hub client actually stored the file instead of a
    # hard-coded ``model/ArtLine_650.pkl``.
    learn = load_learner(model_path.parent, model_path.name)

    # Bind the learner so the interface only supplies the image, and copy
    # ``run``'s metadata (``__name__``, ``__doc__``, ...) onto the partial.
    func = functools.update_wrapper(functools.partial(run, learn=learn), run)
    gr.Interface(
        func,
        [
            gr.inputs.Image(type='file', label='Input Image'),
        ],
        [
            gr.outputs.Image(
                type='pil',
                label='Result'),
        ],
        theme=args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
# Start the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()