|
""" |
|
Contains various utility functions for PyTorch model training and saving. |
|
""" |
|
import torch |
|
from pathlib import Path |
|
|
|
class _InvalidModelNameError(ValueError, AssertionError):
    """Raised when ``model_name`` lacks a ``.pt``/``.pth`` extension.

    Inherits from ``AssertionError`` as well as ``ValueError`` so callers that
    caught the ``AssertionError`` produced by the previous ``assert``-based
    check keep working.
    """


def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str) -> None:
    """Saves a PyTorch model's ``state_dict`` to a target directory.

    Args:
        model: A target PyTorch model to save.
        target_dir: A directory for saving the model to. It is created
            (including any missing parent directories) if it does not exist.
        model_name: A filename for the saved model. Should include
            either ".pth" or ".pt" as the file extension.

    Raises:
        ValueError: If ``model_name`` does not end with ".pth" or ".pt".
            The raised exception is also an ``AssertionError`` for
            backward compatibility with the earlier assert-based check.

    Example usage:
        save_model(model=model_0,
                   target_dir="models",
                   model_name="05_going_modular_tinyvgg_model.pth")
    """
    # Validate with an explicit raise rather than ``assert``: asserts are
    # stripped under ``python -O``, which would silently disable the check.
    # Validating before mkdir also avoids leaving an empty directory behind
    # when the name is rejected.
    if not model_name.endswith((".pth", ".pt")):
        raise _InvalidModelNameError(
            "model_name should end with '.pt' or '.pth'")

    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)
    model_save_path = target_dir_path / model_name

    print(f"[INFO] Saving model to: {model_save_path}")
    # Persist only the state_dict (not the module object itself); loading
    # requires constructing the model class and calling load_state_dict.
    torch.save(obj=model.state_dict(), f=model_save_path)