# crowdcounting / app.py
# Last commit by Shad0ws: "Update app.py" (fa5d33a), 1.64 kB
import numpy
import torch
import gradio as gr
from einops import rearrange
from torchvision import transforms
from model import CANNet
# Restore the pretrained CANNet weights. map_location="cpu" lets a checkpoint
# that was saved from GPU tensors be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only point it at trusted checkpoints.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True; if this
# checkpoint stores non-tensor Python objects, loading may need adjusting — verify.
checkpoint = torch.load(
    'part_B_pre.pth.tar',
    map_location='cpu',
)
model = CANNet()
model.load_state_dict(checkpoint['state_dict'])
# Put the network into evaluation (inference) mode.
model.eval()
# Preprocessing applied to every input image before it reaches the model:
# ToTensor turns the H x W x C image into a C x H x W float tensor, and
# Normalize standardises each channel with the usual ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def crowd(img):
    """Estimate how many people appear in ``img``.

    The image is preprocessed with ``transform``, split into four quadrants,
    and each quadrant is passed through ``model`` on its own. The summed model
    output of all four quadrants is rounded to the nearest integer. The output
    is presumably a density map, whose total is the head count.

    Args:
        img: The image supplied by the Gradio ``Image(type="numpy")`` input.
            Presumably an H x W x 3 array — TODO confirm.

    Returns:
        int: The estimated crowd count.
    """
    # Preprocess and add a leading batch dimension: (C, H, W) -> (1, C, H, W).
    img = transform(img).unsqueeze(0)

    # Split into four quadrants. Presumably this bounds peak memory on large
    # images — TODO confirm. Odd sizes give the extra row/column to the
    # bottom/right halves.
    h_d = img.shape[2] // 2
    w_d = img.shape[3] // 2
    quadrants = (
        img[:, :, :h_d, :w_d],
        img[:, :, :h_d, w_d:],
        img[:, :, h_d:, :w_d],
        img[:, :, h_d:, w_d:],
    )

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        total = sum(float(model(q).sum()) for q in quadrants)

    # round() returns an int (ties go to the even number, as numpy's round does).
    return round(total)
# Gradio UI: one image in, the estimated crowd count out as text.
# Components come from the top-level ``gr`` namespace. The legacy
# ``gr.inputs`` / ``gr.outputs`` modules were removed in Gradio 4.
outputs = gr.Textbox(label="Estimated crowd density:")
inputs = gr.Image(type="numpy", label="Input the image here:")
# Bound to ``demo`` so Gradio's reload mode (``gradio app.py``) can find the app.
demo = gr.Interface(fn=crowd, inputs=inputs, outputs=outputs, allow_flagging="never")
demo.launch(inbrowser=True)