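# Gradio app: scores uploaded images or image URLs with three classifiers
# (NSFW content, art style, aesthetic quality) and exposes the results both
# in the UI and via a named API endpoint ("inference").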
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
import torch
from PIL import Image
import gradio as gr

import aiohttp
import asyncio
from io import BytesIO

device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.float16
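
# Three image-classification pipelines sharing the same device/dtype:
# NSFW content, art style, and aesthetic quality.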
nsfw_pipe = pipeline("image-classification",
                     model=AutoModelForImageClassification.from_pretrained(
                         "carbon225/vit-base-patch16-224-hentai"),
                     feature_extractor=AutoFeatureExtractor.from_pretrained(
                         "carbon225/vit-base-patch16-224-hentai"),
                     device=device,
                     torch_dtype=dtype)


style_pipe = pipeline("image-classification",
                      model=AutoModelForImageClassification.from_pretrained(
                          "cafeai/cafe_style"),
                      feature_extractor=AutoFeatureExtractor.from_pretrained(
                          "cafeai/cafe_style"),
                      device=device,
                      torch_dtype=dtype)

aesthetic_pipe = pipeline("image-classification",
                          model=AutoModelForImageClassification.from_pretrained(
                              "cafeai/cafe_aesthetic"),
                          feature_extractor=AutoFeatureExtractor.from_pretrained(
                              "cafeai/cafe_aesthetic"),
                          device=device,
                          torch_dtype=dtype)


async def fetch_image(session, image_url):
    print(f"fetching image {image_url}")
    async with session.get(image_url) as response:
        # Only accept successful responses that actually contain an image.
        if response.status == 200 and response.headers.get('content-type', '').startswith('image'):
            pil_image = Image.open(BytesIO(await response.read())).convert('RGB')
            # Optional proportional resize, e.g.:
            # pil_image = ImageOps.fit(pil_image, (400, 400), Image.LANCZOS)
            return pil_image
    return None


async def fetch_images(image_urls):
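    # Download all URLs concurrently; failed fetches come back as None.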
    async with aiohttp.ClientSession() as session:
        tasks = [asyncio.ensure_future(fetch_image(
            session, image_url)) for image_url in image_urls]
        return await asyncio.gather(*tasks)


async def predict(json=None, enable_gallery=True, image=None, files=None):
    print(json)

    # Collect PIL images from whichever input was provided.
    pil_images = []
    if image or files:
        if image is not None:
            images_paths = [image]
        else:
            images_paths = [file.name for file in files]
        pil_images = [Image.open(image_path).convert("RGB")
                      for image_path in images_paths]
    elif json is not None:
        # Drop URLs that failed to download or were not images.
        pil_images = [img for img in await fetch_images(json["urls"])
                      if img is not None]

    # Run all three classifiers and merge their label/score lists per image.
    style = style_pipe(pil_images)
    aesthetic = aesthetic_pipe(pil_images)
    nsfw = nsfw_pipe(pil_images)
    results = [a + b + c for (a, b, c) in zip(style, aesthetic, nsfw)]

    # For the single-image case, also return a dict for the gr.Label component.
    label_data = {}
    if image is not None:
        label_data = {row["label"]: row["score"] for row in results[0]}

    return results, label_data, pil_images if enable_gallery else None


with gr.Blocks() as blocks:
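    # Inputs: a single image, a batch of files, or a JSON payload of image URLs.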
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Image to test", type="filepath")
            files = gr.File(label="Multipls Images", file_types=[
                            "image"], file_count="multiple")
            enable_gallery = gr.Checkbox(label="Enable Gallery", value=True)
            json = gr.JSON(label="Image URLs", value={"urls": [
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/b9fb3257-6a54-455e-b636-9d61cf261676.jpg',
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/062eb9be-76eb-4d7e-9299-d1ebea14b46f.jpg',
                'https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/8ff6d4f6-08d0-4a31-818c-4d32ab146f81.jpg']})
        with gr.Column():
            label = gr.Label(label="Labels (first image)")
            results = gr.JSON(label="Results")
            gallery = gr.Gallery().style(grid=[2], height="auto")
    btn = gr.Button("Run")

    btn.click(fn=predict, inputs=[json, enable_gallery, image, files],
              outputs=[results, label, gallery], api_name="inference")

blocks.queue()
blocks.launch(debug=True, inline=True)
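
# Usage sketch for the named endpoint (hedged: assumes Gradio 3.x's REST
# interface, a server at http://127.0.0.1:7860, and that the unused image/file
# inputs accept null; the example URL is a placeholder):
#
#   import requests
#   payload = {"data": [{"urls": ["https://example.com/some-image.jpg"]}, True, None, None]}
#   r = requests.post("http://127.0.0.1:7860/run/inference", json=payload)
#   print(r.json()["data"][0])  # per-image label/score lists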