File size: 2,741 Bytes
c5bd7aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
import numpy as np
import matplotlib.pyplot as plt
from mlxtend.evaluate import confusion_matrix
from mlxtend.plotting import plot_confusion_matrix
import model_builder

def _confusion_counts(y_true, y_pred, num_classes):
    """Count (true, predicted) label pairs into a square confusion matrix.

    The matrix is always ``num_classes x num_classes``, regardless of which labels
    actually occur, so classes that never appear still get a (zero) row and column.

    Args:
        y_true: Sequence of integer ground-truth labels.
        y_pred: Sequence of integer predicted labels, same length as ``y_true``.
        num_classes (int): Number of classes.

    Returns:
        np.ndarray: Matrix whose entry ``[i, j]`` counts samples with true label
        ``i`` that were predicted as ``j`` (rows = true, columns = predicted).

    Raises:
        ValueError: If any label falls outside ``[0, num_classes)``.
    """
    y_true = np.asarray(y_true, dtype=np.int64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.int64).ravel()
    for labels in (y_true, y_pred):
        if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
            raise ValueError(
                f"labels must lie in [0, {num_classes}); "
                f"got range [{labels.min()}, {labels.max()}]"
            )
    # Encode each (true, pred) pair as one flat index, count them, then fold to 2-D.
    pair_index = y_true * num_classes + y_pred
    counts = np.bincount(pair_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def plot_confusion_Matrix(model_path, dataloader, class_names, device, figsize=(12, 12),
                          input_shape=3, hidden_units=15):
    """
    Generate and plot a confusion matrix for a saved PyTorch model over a DataLoader.

    Args:
        model_path: Path to a saved state_dict of a
            ``model_builder.TrashClassificationCNNModel`` (e.g. ".pth" or ".pt").
        dataloader: DataLoader yielding ``(X, y)`` batches of inputs and integer
            labels.
        class_names (list): List of class names. Its length sets the model's output
            size and the confusion-matrix dimensions.
        device: Target device to compute on (e.g., "cuda" or "cpu").
        figsize (tuple): Figure size.
        input_shape (int): ``input_shape`` passed to the model constructor; must
            match the saved checkpoint (default 3).
        hidden_units (int): ``hidden_units`` passed to the model constructor; must
            match the saved checkpoint (default 15).

    Returns:
        None

    Raises:
        ValueError: If a true or predicted label is outside
            ``[0, len(class_names))``.
    """
    num_classes = len(class_names)

    # Rebuild the architecture; its hyper-parameters must match the checkpoint.
    model = model_builder.TrashClassificationCNNModel(input_shape=input_shape,
                                                      hidden_units=hidden_units,
                                                      output_shape=num_classes
                                                      )
    # map_location remaps the stored tensors onto `device`, so a checkpoint saved
    # on a GPU can also be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    # The inputs are sent to `device` below, so the model must live there too;
    # otherwise the forward pass fails with a device-mismatch error.
    model.to(device)
    model.eval()

    y_true = []
    y_pred = []
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y_true.extend(y.cpu().numpy())
            y_logit = model(X)
            y_pred.extend(torch.argmax(y_logit, dim=1).cpu().numpy())

    # Size the matrix from class_names rather than from the labels observed, so it
    # always lines up with the tick labels even when some class never occurs.
    confmat = _confusion_counts(y_true, y_pred, num_classes)

    # Plot confusion matrix
    fig, ax = plot_confusion_matrix(conf_mat=confmat,
                                    class_names=class_names,
                                    figsize=figsize)
    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    plt.show()

def plot_metrics(metrics):
    """
    Draw two side-by-side panels of per-epoch curves: loss on the left, accuracy on the right.

    Each panel overlays the training series (blue) and the testing series (red).
    The epoch axis runs from 1 to ``len(metrics['train_loss'])``.

    Args:
        metrics (dict): Must provide the sequences 'train_loss', 'test_loss',
            'train_acc' and 'test_acc', one value per epoch.

    Returns:
        None
    """
    epoch_axis = range(1, len(metrics['train_loss']) + 1)

    # (train key, test key, train legend, test legend, panel title, y-axis label)
    panel_specs = (
        ('train_loss', 'test_loss', 'Training loss', 'Testing loss',
         'Training and Testing Loss', 'Loss'),
        ('train_acc', 'test_acc', 'Training accuracy', 'Testing accuracy',
         'Training and Testing Accuracy', 'Accuracy'),
    )

    plt.figure(figsize=(12, 5))
    for slot, spec in enumerate(panel_specs, start=1):
        train_key, test_key, train_text, test_text, heading, y_text = spec
        plt.subplot(1, 2, slot)
        plt.plot(epoch_axis, metrics[train_key], 'b', label=train_text)
        plt.plot(epoch_axis, metrics[test_key], 'r', label=test_text)
        plt.title(heading)
        plt.xlabel('Epochs')
        plt.ylabel(y_text)
        plt.legend()

    # Lay out both panels without overlap, then display.
    plt.tight_layout()
    plt.show()