Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -2,14 +2,18 @@ from fastai.vision.core import PILImageBW, TensorImageBW
|
|
2 |
from datasets import ClassLabel
|
3 |
import gradio as gr
|
4 |
from fastai.learner import load_learner
|
|
|
|
|
5 |
|
6 |
def get_image_attr(x): return x['image']
|
7 |
def get_target_attr(x): return x['target']
|
|
|
8 |
|
9 |
def img2tensor(im: Image.Image):
|
10 |
return TensorImageBW(array(im)).unsqueeze(0)
|
11 |
|
12 |
classLabel = ClassLabel(names=['T - shirt / top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'], id=None)
|
|
|
13 |
|
14 |
def add_target(x:dict):
|
15 |
x['target'] = classLabel.int2str(x['label'])
|
@@ -20,13 +24,19 @@ learn = load_learner('export.pkl', cpu=True)
|
|
20 |
def classify(inp):
|
21 |
img = PILImageBW.create(inp)
|
22 |
item = dict(image=img)
|
23 |
-
pred, _,
|
24 |
-
return
|
|
|
|
|
|
|
|
|
25 |
|
26 |
iface = gr.Interface(
|
27 |
fn=classify,
|
28 |
-
inputs=gr.inputs.Image(),
|
29 |
-
outputs=
|
30 |
title="Fashion Mnist Classifier",
|
31 |
description="fastai deployment in Gradio.",
|
|
|
|
|
32 |
).launch()
|
|
|
2 |
from datasets import ClassLabel
|
3 |
import gradio as gr
|
4 |
from fastai.learner import load_learner
|
5 |
+
from PIL import Image
|
6 |
+
from numpy import array
|
7 |
|
8 |
def get_image_attr(x): return x['image']
|
9 |
def get_target_attr(x): return x['target']
|
10 |
+
def get_label_attr(x): return x['label']
|
11 |
|
12 |
def img2tensor(im: Image.Image):
|
13 |
return TensorImageBW(array(im)).unsqueeze(0)
|
14 |
|
15 |
classLabel = ClassLabel(names=['T - shirt / top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'], id=None)
|
16 |
+
labels = classLabel.names
|
17 |
|
18 |
def add_target(x:dict):
|
19 |
x['target'] = classLabel.int2str(x['label'])
|
|
|
24 |
def classify(inp):
|
25 |
img = PILImageBW.create(inp)
|
26 |
item = dict(image=img)
|
27 |
+
pred, _, prob = learn.predict(item)
|
28 |
+
return {label: float(prob[i]) for i, label in enumerate(labels)}
|
29 |
+
# return classLabel.int2str(int(pred))
|
30 |
+
|
31 |
+
examples = ['shoes.jpg', 't-shirt.jpg']
|
32 |
+
interpretation='default'
|
33 |
|
34 |
iface = gr.Interface(
|
35 |
fn=classify,
|
36 |
+
inputs=gr.inputs.Image(image_mode='L'),
|
37 |
+
outputs=gr.outputs.Label(num_top_classes=3),
|
38 |
title="Fashion Mnist Classifier",
|
39 |
description="fastai deployment in Gradio.",
|
40 |
+
examples=examples,
|
41 |
+
interpretation=interpretation,
|
42 |
).launch()
|