justsomerandomdude264's picture
Initial commit
c5bd7aa
raw
history blame
No virus
3.08 kB
import torch
import data_setup, model_builder, engine, utils, plotting
from torchvision import transforms
import argparse
# To Avoid Glitches related to CUDA
def set_memory_limit():
if torch.cuda.is_available():
try:
torch.tensor([1], device='cuda') # Adjust memory fraction as needed
print(f"Device is GPU/CUDA.")
device = 'cuda'
return device
except:
print("Device is CPU.")
device = 'cpu'
return device
# Define argument parsing directly
parser = argparse.ArgumentParser(description="Train a model for Classification of types of Trash.")
parser.add_argument("--train_dir", type=str, default="data/train", help="Directory containing training images")
parser.add_argument("--test_dir", type=str, default="data/test", help="Directory containing testing images")
parser.add_argument("--learning_rate", type=float, default=0.001, help="Learning rate for training")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training")
parser.add_argument("--num_epochs", type=int, default=20, help="Number of epochs to train for")
args = parser.parse_args()
# Set Args up in Variables
train_dir = args.train_dir
test_dir = args.test_dir
LEARNING_RATE = args.learning_rate
BATCH_SIZE = args.batch_size
NUM_EPOCHS = args.num_epochs
HIDDEN_UNITS = 15
# Data transformation
data_transform = transforms.Compose([
transforms.Resize((112, 112)),
transforms.ToTensor()
])
# Create DataLoaders
train_dataloader, test_dataloader, class_names = data_setup.train_test_dataloader(
train_dir=train_dir,
test_dir=test_dir,
transform=data_transform,
batch_size=BATCH_SIZE
)
# Model creation
device = set_memory_limit()
model = model_builder.TrashClassificationCNNModel(input_shape=3,
hidden_units=HIDDEN_UNITS,
output_shape=len(class_names)
).to(device)
# Loss function and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),
lr=LEARNING_RATE)
# Start training
metrics = engine.train(model=model,
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
optimizer=optimizer,
loss_fn=loss_fn,
epochs=NUM_EPOCHS,
device=device)
# Save the model
utils.save_model(model=model,
target_dir="models",
model_name="Trash_Classification_Model_COLOURED.pth")
# Clear CUDA cache
torch.cuda.empty_cache()
# Plot the Confusion Matrix
plotting.plot_confusion_Matrix(model_path="models\Trash_Classification_Model_COLOURED.pth",
dataloader=test_dataloader,
class_names=class_names,
device=device)
# Plot the Loss and Accuracy
plotting.plot_metrics(metrics)