justsomerandomdude264's picture
Initial commit
c5bd7aa
raw
history blame
No virus
2.74 kB
import torch
import numpy as np
import matplotlib.pyplot as plt
from mlxtend.evaluate import confusion_matrix
from mlxtend.plotting import plot_confusion_matrix
import model_builder
def plot_confusion_Matrix(model_path, dataloader, class_names, device, figsize=(12, 12)):
"""
Generate and plot confusion matrix using mlxtend library from a PyTorch model and DataLoader.
Args:
model: PyTorch model's path eg(".pth" or ".pt").
dataloader: DataLoader instance for the dataset.
class_names (list): List of class names.
device: Target device to compute on (e.g., "cuda" or "cpu").
figsize (tuple): Figure size.
Returns:
None
"""
# Load the model
model = model_builder.TrashClassificationCNNModel(input_shape=3,
hidden_units=15,
output_shape=len(class_names)
)
model.load_state_dict(torch.load(model_path))
model.eval()
y_true = []
y_pred = []
with torch.no_grad():
for X, y in dataloader:
X = X.to(device)
y_true.extend(y.cpu().numpy())
y_logit = model(X)
y_pred.extend(torch.argmax(y_logit, dim=1).cpu().numpy())
confmat = confusion_matrix(y_target=y_true, y_predicted=y_pred, binary=False)
# Plot confusion matrix
fig, ax = plot_confusion_matrix(conf_mat=confmat,
class_names=class_names,
figsize=figsize)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()
def plot_metrics(metrics):
"""
Plots training and testing loss and accuracy.
Args:
metrics (dict): A dictionary containing training and testing loss and accuracy.
Returns:
None
"""
epochs = range(1, len(metrics['train_loss']) + 1)
# Plot training and testing loss
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(epochs, metrics['train_loss'], 'b', label='Training loss')
plt.plot(epochs, metrics['test_loss'], 'r', label='Testing loss')
plt.title('Training and Testing Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
# Plot training and testing accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs, metrics['train_acc'], 'b', label='Training accuracy')
plt.plot(epochs, metrics['test_acc'], 'r', label='Testing accuracy')
plt.title('Training and Testing Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
# Show plot
plt.tight_layout()
plt.show()