File size: 287 Bytes
4f6613a
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from contextlib import nullcontext

import torch


def autocast_exclude_mps(
    device_type: str, dtype: torch.dtype
) -> nullcontext | torch.autocast:
    return (
        nullcontext()
        if torch.backends.mps.is_available()
        else torch.autocast(device_type, dtype)
    )