import torch
import numpy as np
import matplotlib.pyplot as plt
from mlxtend.evaluate import confusion_matrix
from mlxtend.plotting import plot_confusion_matrix

import model_builder


def plot_confusion_Matrix(model_path, dataloader, class_names, device, figsize=(12, 12)):
    """
    Generates and plots a confusion matrix using the mlxtend library from a saved PyTorch model and a DataLoader.

    Args:
        model_path: Path to the saved PyTorch model weights (e.g. a ".pth" or ".pt" file).
        dataloader: DataLoader instance for the dataset to evaluate.
        class_names (list): List of class names.
        device: Target device to compute on (e.g. "cuda" or "cpu").
        figsize (tuple): Figure size.

    Returns:
        None
    """
    # Rebuild the model architecture, load the saved weights and move the model to the target device
    model = model_builder.TrashClassificationCNNModel(input_shape=3,
                                                      hidden_units=15,
                                                      output_shape=len(class_names))
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    y_true = []
    y_pred = []

    # Run inference and collect ground-truth and predicted labels
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y_true.extend(y.cpu().numpy())
            y_logit = model(X)
            y_pred.extend(torch.argmax(y_logit, dim=1).cpu().numpy())

    confmat = confusion_matrix(y_target=y_true, y_predicted=y_pred, binary=False)

    fig, ax = plot_confusion_matrix(conf_mat=confmat,
                                    class_names=class_names,
                                    figsize=figsize)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()
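
# A minimal usage sketch for plot_confusion_Matrix. The checkpoint path, dataset and
# DataLoader names below are assumptions for illustration, not objects defined in this module:
#
#   from torch.utils.data import DataLoader
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)  # test_dataset is hypothetical
#   plot_confusion_Matrix(model_path="models/trash_model.pth",  # hypothetical checkpoint path
#                         dataloader=test_dataloader,
#                         class_names=test_dataset.classes,
#                         device=device)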


def plot_metrics(metrics):
    """
    Plots training and testing loss and accuracy curves side by side.

    Args:
        metrics (dict): Dictionary with "train_loss", "test_loss", "train_acc" and "test_acc" lists.

    Returns:
        None
    """
    epochs = range(1, len(metrics['train_loss']) + 1)

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, metrics['train_loss'], 'b', label='Training loss')
    plt.plot(epochs, metrics['test_loss'], 'r', label='Testing loss')
    plt.title('Training and Testing Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs, metrics['train_acc'], 'b', label='Training accuracy')
    plt.plot(epochs, metrics['test_acc'], 'r', label='Testing accuracy')
    plt.title('Training and Testing Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.show()
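
# A minimal usage sketch for plot_metrics. The dict below only shows the expected key layout
# ("train_loss", "test_loss", "train_acc", "test_acc"); the numbers are made-up illustration values:
#
#   metrics = {"train_loss": [1.05, 0.78, 0.62],
#              "test_loss":  [1.10, 0.85, 0.70],
#              "train_acc":  [0.45, 0.61, 0.72],
#              "test_acc":   [0.42, 0.58, 0.69]}
#   plot_metrics(metrics)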