# Spaces:
# Runtime error
# Runtime error
import gc | |
import inspect | |
import threading | |
from functools import wraps | |
import torch | |
from ia_check_versions import ia_check_versions | |
# Module-wide semaphore with a single permit. async_post_clear_cache() hands
# it to post_clear_cache(), which holds it while the caches are cleared.
# NOTE(review): presumably model-loading/inference code elsewhere acquires
# the same semaphore -- confirm against callers.
model_access_sem = threading.Semaphore(value=1)
def torch_gc():
    """Ask PyTorch to release cached accelerator memory.

    When CUDA is available, calls ``torch.cuda.empty_cache()`` followed by
    ``torch.cuda.ipc_collect()``. Independently of that, when
    ``ia_check_versions.torch_mps_is_available`` is truthy and this torch
    build exposes ``torch.mps.empty_cache``, calls it as well.
    """
    cuda = torch.cuda
    if cuda.is_available():
        cuda.empty_cache()
        cuda.ipc_collect()

    # Short-circuit order matters: ``torch.mps`` is only touched after
    # ``hasattr(torch, "mps")`` has confirmed that the attribute exists.
    can_empty_mps = (
        ia_check_versions.torch_mps_is_available
        and hasattr(torch, "mps")
        and hasattr(torch.mps, "empty_cache")
    )
    if can_empty_mps:
        torch.mps.empty_cache()
def clear_cache():
    """Run a Python garbage-collection pass, then call ``torch_gc()``."""
    # Keep this order: the collector runs before the torch caches are emptied.
    gc.collect()
    torch_gc()
def post_clear_cache(sem):
    """Clear the caches while holding ``sem``.

    Args:
        sem: A context manager used as a lock (``async_post_clear_cache``
            passes the module-level ``model_access_sem``). It is acquired
            before the cleanup starts and released when the cleanup
            finishes, including when the cleanup raises.
    """
    with sem:
        # Reuse the shared helper rather than repeating its body here, so
        # both entry points always perform the same cleanup steps.
        clear_cache()
def async_post_clear_cache():
    """Start a background thread that runs ``post_clear_cache``.

    The thread is given ``model_access_sem``, so the cleanup only runs once
    that semaphore can be acquired. This function does not join the thread;
    it returns ``None`` right after starting it.
    """
    threading.Thread(
        target=post_clear_cache,
        args=(model_access_sem,),
    ).start()
def clear_cache_decorator(func): | |
def yield_wrapper(*args, **kwargs): | |
clear_cache() | |
yield from func(*args, **kwargs) | |
clear_cache() | |
def wrapper(*args, **kwargs): | |
clear_cache() | |
res = func(*args, **kwargs) | |
clear_cache() | |
return res | |
if inspect.isgeneratorfunction(func): | |
return yield_wrapper | |
else: | |
return wrapper | |