Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 2,526 Bytes
6974603 caeb1f4 6974603 caeb1f4 6974603 caeb1f4 6974603 caeb1f4 6974603 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline
import torch
from PIL import Image
import gradio as gr
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Use half precision only when a GPU is present. float16 inference on CPU has
# historically been slow or unsupported for common ops, and this Space may run
# on CPU, so fall back to float32 there.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32


def _image_classifier(repo_id):
    """Build an image-classification pipeline for the Hub checkpoint *repo_id*.

    The model and its feature extractor are both loaded from *repo_id* and
    handed to ``pipeline`` together with the module-level ``device`` and
    ``dtype``.

    Args:
        repo_id: Hugging Face Hub repository id of the classifier.

    Returns:
        The pipeline object returned by ``transformers.pipeline``.
    """
    # NOTE(review): when ``model`` is an already-instantiated model,
    # ``pipeline`` may not re-cast it to ``torch_dtype``. Confirm against the
    # installed transformers version if half precision matters on GPU.
    return pipeline(
        "image-classification",
        model=AutoModelForImageClassification.from_pretrained(repo_id),
        feature_extractor=AutoFeatureExtractor.from_pretrained(repo_id),
        device=device,
        torch_dtype=dtype,
    )


# Loaded eagerly at import time, in the same order as before.
nsfw_pipe = _image_classifier("carbon225/vit-base-patch16-224-hentai")
style_pipe = _image_classifier("cafeai/cafe_style")
aesthetic_pipe = _image_classifier("cafeai/cafe_aesthetic")
def predict(image, files=None):
    """Classify one or more images with the style, aesthetic and NSFW models.

    Args:
        image: Path to a single image, or ``None``. The UI's ``gr.Image``
            input uses ``type="filepath"``.
        files: Optional sequence of uploaded files from the multi-file input.
            Each entry may be an object exposing ``.name``, such as a temp-file
            wrapper, or a plain path string. When non-empty, these take
            precedence over ``image``.

    Returns:
        A ``(label, results)`` tuple.

        - ``label`` is a ``{label: score}`` dict for the first classified
          image, or ``None`` when there was nothing to classify.
        - ``results`` has one entry per image. Each entry concatenates the
          ``{"label": ..., "score": ...}`` rows from the style, aesthetic and
          NSFW pipelines, in that order.
    """
    print(image, files)

    # Uploaded files win over the single-image input. An empty list is
    # treated like "no files" instead of producing an empty batch.
    if files:
        image_paths = [getattr(f, "name", f) for f in files]
    elif image:
        image_paths = [image]
    else:
        # Nothing to classify. Previously this crashed in Image.open(None).
        return None, []

    pil_images = []
    for path in image_paths:
        # convert() yields a new, fully loaded image, so the source file
        # handle can be closed right away by the context manager.
        with Image.open(path) as img:
            pil_images.append(img.convert("RGB"))

    style = style_pipe(pil_images)
    aesthetic = aesthetic_pipe(pil_images)
    nsfw = nsfw_pipe(pil_images)
    results = [a + b + c for a, b, c in zip(style, aesthetic, nsfw)]

    # Build the Label payload from the per-image results, whichever input
    # supplied them. The old `if image:` guard left this as {} for
    # multi-file uploads, so indexing it with [0] raised KeyError.
    label_data = [
        {row["label"]: row["score"] for row in rows} for rows in results
    ]
    return (label_data[0] if label_data else None), results
with gr.Blocks() as blocks:
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Image to test", type="filepath")
            files = gr.File(
                label="Multiple Images",
                file_types=["image"],
                file_count="multiple",
            )
        with gr.Column():
            # Receives the first return value of `predict`: a {label: score}
            # dict for the first image.
            label = gr.Label(label="style")
            results = gr.JSON(label="Results")
    btn = gr.Button("Run")
    btn.click(
        fn=predict,
        inputs=[image, files],
        outputs=[label, results],
        api_name="inference",
    )

blocks.queue()
blocks.launch(debug=True, inline=True)