language: en
tags:
- image-classification
- CNN
- Convolutional Neural Network
- Neural Network
- Trash
metrics:
- name: train-accuracy
  value: 91%
- name: test-accuracy
  value: 55%
pipeline:
- image-classification
libraries:
- name: torch
  version: 1.9.0
- name: torchvision
  version: 0.10.0
- name: numpy
  version: 1.21.0
Trash Classification CNN Model
About
This project is a convolutional neural network (CNN) model that classifies different types of trash items.
The CNN model in this project utilizes the TinyVGG architecture, a compact version of the popular VGG neural network architecture. The model is trained to classify trash items into the following subcategories:
- Cardboard
- Food Organics
- Glass
- Metal
- Miscellaneous Trash
- Paper
- Plastic
- Textile Trash
- Vegetation
In total, there are 9 categories into which the trash items are classified.
For more details about the CNN architecture used in this project, you can refer to the CNN Explainer website.
Info
Only 30% of the RealWaste dataset was used, divided into an 80%/20% train/test split.
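The subsampling and split described above can be sketched with the standard library alone. This is a minimal illustration, not the project's actual procedure: the function name `subsample_and_split`, the seed, and the placeholder file names are assumptions, and the card does not say whether the split was stratified per class (this sketch is not).

```python
import random


def subsample_and_split(paths, keep_fraction=0.3, train_fraction=0.8, seed=42):
    """Keep a random fraction of the items, then split them into train and test."""
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    shuffled = list(paths)
    rng.shuffle(shuffled)
    kept = shuffled[: int(len(shuffled) * keep_fraction)]
    n_train = int(len(kept) * train_fraction)
    return kept[:n_train], kept[n_train:]


# Hypothetical file names standing in for the 4752 RealWaste images.
all_images = [f"img_{i}.jpg" for i in range(4752)]
train_files, test_files = subsample_and_split(all_images)
print(len(train_files), len(test_files))
```

With a small subset like this, a per-class (stratified) split would keep rarer categories represented in both halves; whether the original project did so is not stated.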
The Hugging Face repository contains 7 files, found in the Files and versions tab:
data_setup.py: This file contains functions for setting up the data into datasets using ImageFolder and then turning it into batches using DataLoader. It also returns the names of the classes.
model_builder.py: This file contains a class which subclasses nn.Module and replicates the TinyVGG CNN model architecture with a few modifications here and there.
engine.py: This file contains three functions: train_step, test_step, and train. The first two train and test the model, respectively, and the last one combines them into the full training loop.
plotting.py: This file contains plot_metrics, which plots metrics like loss and accuracy, and plot_confusion_Matrix, which plots the confusion matrix.
predict.py: This file can be run with the --image and --model_path arguments to get the model's prediction for the specified image.
utils.py: This file contains functions to save the model in a specific folder under a configurable name.
train.py: This script uses all the files except predict.py and accepts argument flags to change hyperparameters. It can be run with the following arguments:
python train.py --train_dir TRAIN_DIR --test_dir TEST_DIR --learning_rate LEARNING_RATE --batch_size BATCH_SIZE --num_epochs NUM_EPOCHS
It is also device-agnostic, meaning it automatically uses whichever compute device is available.
The repository also contains 3 folders:
data: This stores the data and has subdirectories train and test.
models: This stores the model saved by utils.py.
samples: This has 10 pictures you can use to test the model with predict.py.
Model Overview
This model is designed for image classification tasks. It requires input images of size 112x112 pixels. It contains 2 blocks, each with 2 convolutional layers, followed by a flatten layer and a linear classifier.
The architecture looks like this:
TrashClassificationCNNModel(
  (block_1): Sequential(
    (0): Conv2d(3, 15, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(15, 15, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (block_2): Sequential(
    (0): Conv2d(15, 15, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(15, 15, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=11760, out_features=9, bias=True)
  )
)
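For readers who want to reproduce the printed structure, the following sketch rebuilds it layer by layer. The class name `TinyVGGSketch` and the constructor parameters are assumptions; the actual class in model_builder.py may be written differently. The sketch also shows where in_features=11760 comes from: the padded 3x3 convolutions keep the spatial size, each max-pool halves it (112 to 56 to 28), and 15 channels x 28 x 28 = 11760.

```python
import torch
from torch import nn


class TinyVGGSketch(nn.Module):
    """Hypothetical re-creation of the printed architecture (not the repo's code)."""

    def __init__(self, in_channels=3, hidden_units=15, num_classes=9):
        super().__init__()
        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 112x112 -> 56x56
        )
        self.block_2 = nn.Sequential(
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 56x56 -> 28x28
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 15 channels * 28 * 28 = 11760 features for a 112x112 input
            nn.Linear(hidden_units * 28 * 28, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.block_2(self.block_1(x)))


model = TinyVGGSketch()
dummy = torch.randn(1, 3, 112, 112)  # one random RGB image at the required size
logits = model(dummy)
print(logits.shape)  # torch.Size([1, 9])
```

Because the linear layer's input size is fixed, images of any other resolution would fail at the classifier, which is why inputs must be resized to 112x112 first.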
Dataset Overview
The dataset used, RealWaste, contains images of waste items across multiple classes. It has 4752 samples.
- Source: Click here
- Citation: Single, Sam, Iranmanesh, Saeid, and Raad, Raad. (2023). RealWaste. UCI Machine Learning Repository. https://doi.org/10.24432/C5SS4G.
Disclaimer
The model might give inaccurate or wrong results.