File size: 433 Bytes
128757a
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from contextlib import ContextDecorator, contextmanager

class nullcontext(ContextDecorator):
    """No-op stand-in for a context manager such as ``autocast``.

    Keyword arguments (e.g. ``enabled=...``) are accepted and ignored so this
    can replace AMP APIs verbatim; entering the context returns
    *enter_result*. Unlike a ``@contextmanager`` generator, one instance can
    be entered any number of times (``ctx = nullcontext(); with ctx: ...``
    inside a loop), matching real ``autocast`` objects. Inheriting
    ``ContextDecorator`` keeps ``@nullcontext(...)`` usable as a decorator.
    """

    def __init__(self, enter_result=None, **kwargs):
        self.enter_result = enter_result

    def __enter__(self):
        return self.enter_result

    def __exit__(self, *exc_info):
        # Returning False lets any exception raised in the body propagate.
        return False

class _NoOpGradScaler:
    """Fallback for ``GradScaler`` when AMP is unavailable.

    Provides the public methods of ``torch.cuda.amp.GradScaler`` with the
    semantics of a disabled scaler: outputs are returned unscaled and
    ``step`` simply calls ``optimizer.step``. Constructor arguments
    (``init_scale``, ``enabled``, ...) are accepted and ignored so call sites
    need no special-casing.
    """

    def __init__(self, *args, **kwargs):
        pass

    def scale(self, outputs):
        # Nothing to scale: hand the loss/outputs back untouched.
        return outputs

    def unscale_(self, optimizer):
        pass

    def step(self, optimizer, *args, **kwargs):
        return optimizer.step(*args, **kwargs)

    def update(self, new_scale=None):
        pass

    def get_scale(self):
        return 1.0

    def is_enabled(self):
        return False

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass


def _passthrough_decorator(func=None, **kwargs):
    """Fallback for ``custom_fwd``/``custom_bwd``: leave the function as-is.

    Handles both the bare form (``@custom_fwd``) and the parametrised form
    (``@custom_fwd(cast_inputs=...)``); keyword options are ignored.
    """
    if func is None:
        # Parametrised form: return the actual decorator.
        return lambda f: f
    return func


try:
    from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
except ImportError:
    # Only an unavailable AMP library should trigger the fallback; a bare
    # ``except`` would also swallow KeyboardInterrupt and SystemExit.
    print('[Warning] Library for automatic mixed precision is not found, AMP is disabled!!')
    GradScaler = _NoOpGradScaler
    autocast = nullcontext
    custom_fwd = _passthrough_decorator
    custom_bwd = _passthrough_decorator